Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f6b97293
Unverified
Commit
f6b97293
authored
Oct 21, 2024
by
Dhia Eddine Rhaiem
Committed by
GitHub
Oct 21, 2024
Browse files
[Model] FalconMamba Support (#9325)
parent
496e991d
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
35 additions
and
12 deletions
+35
-12
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+5
-0
tests/models/decoder_only/language/test_mamba.py
tests/models/decoder_only/language/test_mamba.py
+1
-1
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+0
-1
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+28
-10
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
docs/source/models/supported_models.rst
View file @
f6b97293
...
@@ -87,6 +87,11 @@ Text Generation
...
@@ -87,6 +87,11 @@ Text Generation
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
-
-
- ✅︎
- ✅︎
* - :code:`FalconMambaForCausalLM`
- FalconMamba
- :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc.
- ✅︎
-
* - :code:`GemmaForCausalLM`
* - :code:`GemmaForCausalLM`
- Gemma
- Gemma
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
...
...
tests/models/decoder_only/language/test_mamba.py
View file @
f6b97293
...
@@ -10,7 +10,7 @@ from vllm.worker.model_runner import _get_graph_batch_size
...
@@ -10,7 +10,7 @@ from vllm.worker.model_runner import _get_graph_batch_size
from
...utils
import
check_outputs_equal
from
...utils
import
check_outputs_equal
MODELS
=
[
"state-spaces/mamba-130m-hf"
]
MODELS
=
[
"state-spaces/mamba-130m-hf"
,
"tiiuae/falcon-mamba-tiny-dev"
]
# Use lower-level interfaces to create this greedy generator, as mamba will
# Use lower-level interfaces to create this greedy generator, as mamba will
...
...
vllm/model_executor/layers/layernorm.py
View file @
f6b97293
...
@@ -27,7 +27,6 @@ class RMSNorm(CustomOp):
...
@@ -27,7 +27,6 @@ class RMSNorm(CustomOp):
self
.
variance_epsilon
=
eps
self
.
variance_epsilon
=
eps
self
.
variance_size_override
=
(
None
if
var_hidden_size
==
hidden_size
self
.
variance_size_override
=
(
None
if
var_hidden_size
==
hidden_size
else
var_hidden_size
)
else
var_hidden_size
)
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
def
forward_native
(
def
forward_native
(
...
...
vllm/model_executor/models/mamba.py
View file @
f6b97293
...
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
...
@@ -59,7 +59,7 @@ class MambaMixer(nn.Module):
...
@@ -59,7 +59,7 @@ class MambaMixer(nn.Module):
self
.
conv_kernel_size
=
config
.
conv_kernel
self
.
conv_kernel_size
=
config
.
conv_kernel
self
.
intermediate_size
=
config
.
intermediate_size
self
.
intermediate_size
=
config
.
intermediate_size
self
.
time_step_rank
=
int
(
config
.
time_step_rank
)
self
.
time_step_rank
=
int
(
config
.
time_step_rank
)
self
.
is_falcon_mamba
=
config
.
model_type
==
"falcon_mamba"
self
.
conv1d
=
ColumnParallelLinear
(
self
.
conv1d
=
ColumnParallelLinear
(
input_size
=
self
.
conv_kernel_size
,
input_size
=
self
.
conv_kernel_size
,
output_size
=
self
.
intermediate_size
,
output_size
=
self
.
intermediate_size
,
...
@@ -109,6 +109,13 @@ class MambaMixer(nn.Module):
...
@@ -109,6 +109,13 @@ class MambaMixer(nn.Module):
input_is_parallel
=
True
,
input_is_parallel
=
True
,
)
)
self
.
activation
=
config
.
hidden_act
self
.
activation
=
config
.
hidden_act
if
self
.
is_falcon_mamba
:
self
.
dt_layernorm
=
RMSNorm
(
self
.
time_step_rank
,
eps
=
config
.
mixer_rms_eps
)
self
.
b_layernorm
=
RMSNorm
(
self
.
ssm_state_size
,
eps
=
config
.
mixer_rms_eps
)
self
.
c_layernorm
=
RMSNorm
(
self
.
ssm_state_size
,
eps
=
config
.
mixer_rms_eps
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
...
@@ -158,8 +165,12 @@ class MambaMixer(nn.Module):
...
@@ -158,8 +165,12 @@ class MambaMixer(nn.Module):
[
self
.
time_step_rank
,
self
.
ssm_state_size
,
self
.
ssm_state_size
],
[
self
.
time_step_rank
,
self
.
ssm_state_size
,
self
.
ssm_state_size
],
dim
=-
1
,
dim
=-
1
,
)
)
# Note that Jamba and FalconMamba normalizes B, C, and time_step here
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
# but Mamba doesn't.
if
self
.
is_falcon_mamba
:
time_step
=
self
.
dt_layernorm
(
time_step
.
contiguous
())
B
=
self
.
b_layernorm
(
B
.
contiguous
())
C
=
self
.
c_layernorm
(
C
.
contiguous
())
discrete_time_step
=
self
.
dt_proj
(
time_step
)[
0
].
transpose
(
-
2
,
-
1
)
discrete_time_step
=
self
.
dt_proj
(
time_step
)[
0
].
transpose
(
-
2
,
-
1
)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
...
@@ -213,11 +224,9 @@ class MambaDecoderLayer(nn.Module):
...
@@ -213,11 +224,9 @@ class MambaDecoderLayer(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
layer_idx
=
layer_idx
self
.
config
=
config
self
.
config
=
config
self
.
is_falcon_mamba
=
config
.
model_type
==
"falcon_mamba"
self
.
mixer
=
MambaMixer
(
config
,
layer_idx
)
self
.
mixer
=
MambaMixer
(
config
,
layer_idx
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
pre_ff_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -319,8 +328,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -319,8 +328,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
backbone
.
embeddings
self
.
lm_head
=
self
.
backbone
.
embeddings
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
)
# Used to track and store by the Mamba cache between steps.
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
...
@@ -398,7 +417,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -398,7 +417,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"A_log"
in
name
:
if
"A_log"
in
name
:
name
=
name
.
replace
(
"A_log"
,
"A"
)
name
=
name
.
replace
(
"A_log"
,
"A"
)
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
...
...
vllm/model_executor/models/registry.py
View file @
f6b97293
...
@@ -53,6 +53,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -53,6 +53,7 @@ _TEXT_GENERATION_MODELS = {
# For decapoda-research/llama-*
# For decapoda-research/llama-*
"LLaMAForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"LLaMAForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"MambaForCausalLM"
:
(
"mamba"
,
"MambaForCausalLM"
),
"MambaForCausalLM"
:
(
"mamba"
,
"MambaForCausalLM"
),
"FalconMambaForCausalLM"
:
(
"mamba"
,
"MambaForCausalLM"
),
"MistralForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"MistralForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"MixtralForCausalLM"
:
(
"mixtral"
,
"MixtralForCausalLM"
),
"MixtralForCausalLM"
:
(
"mixtral"
,
"MixtralForCausalLM"
),
"QuantMixtralForCausalLM"
:
(
"mixtral_quant"
,
"MixtralForCausalLM"
),
"QuantMixtralForCausalLM"
:
(
"mixtral_quant"
,
"MixtralForCausalLM"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment