Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
169b5306
"docs/vscode:/vscode.git/clone" did not exist on "c5cffcd0cdbba9273954b4fd1317137208ce564c"
Unverified
Commit
169b5306
authored
Oct 14, 2024
by
Tyler Michael Smith
Committed by
GitHub
Oct 15, 2024
Browse files
[Bugfix] Clean up some cruft in mamba.py (#9343)
parent
f0fe4fe8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
104 deletions
+11
-104
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+1
-1
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+10
-103
No files found.
docs/source/models/supported_models.rst
View file @
169b5306
...
@@ -155,7 +155,7 @@ Text Generation
...
@@ -155,7 +155,7 @@ Text Generation
* - :code:`MambaForCausalLM`
* - :code:`MambaForCausalLM`
- Mamba
- Mamba
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
-
✅︎
-
-
-
* - :code:`MiniCPMForCausalLM`
* - :code:`MiniCPMForCausalLM`
- MiniCPM
- MiniCPM
...
...
vllm/model_executor/models/mamba.py
View file @
169b5306
# coding=utf-8
# coding=utf-8
"""PyTorch MAMBA model."""
"""PyTorch MAMBA model."""
from
dataclasses
import
dataclass
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -10,7 +9,6 @@ from transformers import MambaConfig
...
@@ -10,7 +9,6 @@ from transformers import MambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -39,13 +37,6 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
...
@@ -39,13 +37,6 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
@
dataclass
class
MambaCacheParams
:
is_prompt
:
bool
=
False
conv_state
:
torch
.
Tensor
=
torch
.
Tensor
()
ssm_state
:
torch
.
Tensor
=
torch
.
Tensor
()
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class
MambaMixer
(
nn
.
Module
):
class
MambaMixer
(
nn
.
Module
):
"""
"""
...
@@ -209,37 +200,6 @@ class MambaMixer(nn.Module):
...
@@ -209,37 +200,6 @@ class MambaMixer(nn.Module):
return
contextualized_states
return
contextualized_states
class
MambaMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
MambaConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
hidden_size
=
config
.
hidden_size
intermediate_size
=
config
.
intermediate_size
hidden_act
=
config
.
hidden_act
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
MambaDecoderLayer
(
nn
.
Module
):
class
MambaDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -252,7 +212,6 @@ class MambaDecoderLayer(nn.Module):
...
@@ -252,7 +212,6 @@ class MambaDecoderLayer(nn.Module):
self
.
config
=
config
self
.
config
=
config
self
.
mixer
=
MambaMixer
(
config
,
layer_idx
)
self
.
mixer
=
MambaMixer
(
config
,
layer_idx
)
self
.
feed_forward
=
MambaMLP
(
config
,
quant_config
=
quant_config
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
pre_ff_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
pre_ff_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
eps
=
config
.
layer_norm_epsilon
)
...
@@ -274,10 +233,6 @@ class MambaDecoderLayer(nn.Module):
...
@@ -274,10 +233,6 @@ class MambaDecoderLayer(nn.Module):
hidden_states
=
self
.
mixer
(
hidden_states
,
attn_metadata
,
conv_state
,
hidden_states
=
self
.
mixer
(
hidden_states
,
attn_metadata
,
conv_state
,
ssm_state
)
ssm_state
)
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
return
hidden_states
,
residual
return
hidden_states
,
residual
...
@@ -319,7 +274,6 @@ class MambaModel(nn.Module):
...
@@ -319,7 +274,6 @@ class MambaModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
,
...
@@ -346,26 +300,6 @@ class MambaModel(nn.Module):
...
@@ -346,26 +300,6 @@ class MambaModel(nn.Module):
class
MambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
):
class
MambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"embed_tokens"
,
"lm_head"
,
]
embedding_modules
=
{
"embeddings"
:
"input_embeddings"
,
"lm_head"
:
"output_embeddings"
,
}
embedding_padding_modules
=
[
"lm_head"
]
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -416,8 +350,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -416,8 +350,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
mamba_cache_tensors
=
self
.
mamba_cache
.
current_run_tensors
(
mamba_cache_tensors
=
self
.
mamba_cache
.
current_run_tensors
(
input_ids
,
attn_metadata
,
**
kwargs
)
input_ids
,
attn_metadata
,
**
kwargs
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
attn_metadata
,
attn_metadata
,
mamba_cache_tensors
[
0
],
mamba_cache_tensors
[
0
],
mamba_cache_tensors
[
1
])
mamba_cache_tensors
[
1
])
return
hidden_states
return
hidden_states
...
@@ -457,38 +391,11 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -457,38 +391,11 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
"A_log"
in
name
:
if
"A_log"
in
name
:
name
=
name
.
replace
(
"A_log"
,
"A"
)
name
=
name
.
replace
(
"A_log"
,
"A"
)
if
".self_attn."
in
name
:
name
=
name
.
replace
(
".self_attn"
,
""
)
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment