Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
603a661a
Unverified
Commit
603a661a
authored
Nov 04, 2024
by
Mor Zusman
Committed by
GitHub
Nov 04, 2024
Browse files
[Model] factoring out MambaMixer out of Jamba (#8993)
Signed-off-by:
mzusman
<
mor.zusmann@gmail.com
>
parent
fb2716d6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
245 additions
and
374 deletions
+245
-374
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+217
-0
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+14
-185
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+14
-189
No files found.
vllm/model_executor/layers/mamba/mamba_mixer.py
0 → 100644
View file @
603a661a
import
torch
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.utils
import
set_weight_attrs
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@
CustomOp
.
register
(
"mamba_mixer"
)
class
MambaMixer
(
CustomOp
):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
for why A isn't selective) ∆, B, C are input-dependent
(this is a key difference between Mamba and the linear time
invariant S4, and is why Mamba is called
**selective** state spaces)
"""
def
__init__
(
self
,
hidden_size
:
int
,
ssm_state_size
:
int
,
conv_kernel_size
:
int
,
intermediate_size
:
int
,
time_step_rank
:
int
,
use_conv_bias
:
bool
,
use_bias
:
bool
,
use_rms_norm
:
bool
,
rms_norm_eps
:
float
=
1e-5
,
activation
=
"silu"
):
super
().
__init__
()
self
.
time_step_rank
=
time_step_rank
self
.
ssm_state_size
=
ssm_state_size
self
.
use_rms_norm
=
use_rms_norm
self
.
activation
=
activation
self
.
conv1d
=
ColumnParallelLinear
(
input_size
=
conv_kernel_size
,
output_size
=
intermediate_size
,
bias
=
use_conv_bias
,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow to override it
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
in_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
use_bias
)
# selective projection used to make dt, B and C input dependent
self
.
x_proj
=
RowParallelLinear
(
intermediate_size
,
time_step_rank
+
ssm_state_size
*
2
,
bias
=
False
,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self
.
dt_proj
=
ColumnParallelLinear
(
time_step_rank
,
intermediate_size
,
bias
=
True
,
skip_bias_add
=
True
)
def
weight_loader
(
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
param
.
data
.
copy_
(
loaded_weight
.
data
.
split
(
loaded_weight
.
shape
[
0
]
//
tp_size
,
dim
=
0
)[
tp_rank
])
def
A_weight_loader
(
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
weight_loader
(
param
,
-
torch
.
exp
(
loaded_weight
.
float
()))
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
A
=
nn
.
Parameter
(
torch
.
empty
(
intermediate_size
//
tp_size
,
ssm_state_size
,
dtype
=
torch
.
float32
,
))
self
.
D
=
nn
.
Parameter
(
torch
.
ones
(
intermediate_size
//
tp_size
))
set_weight_attrs
(
self
.
D
,
{
"weight_loader"
:
weight_loader
})
set_weight_attrs
(
self
.
A
,
{
"weight_loader"
:
A_weight_loader
})
self
.
out_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
use_bias
,
input_is_parallel
=
True
,
)
self
.
dt_layernorm
=
RMSNorm
(
time_step_rank
,
eps
=
rms_norm_eps
)
if
use_rms_norm
else
None
self
.
b_layernorm
=
RMSNorm
(
ssm_state_size
,
eps
=
rms_norm_eps
)
if
use_rms_norm
else
None
self
.
c_layernorm
=
RMSNorm
(
ssm_state_size
,
eps
=
rms_norm_eps
)
if
use_rms_norm
else
None
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
):
pass
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
):
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
hidden_states
,
gate
=
projected_states
.
chunk
(
2
,
dim
=-
2
)
# 2. Convolution sequence transformation
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
))
if
attn_metadata
.
query_start_loc
is
not
None
\
and
attn_metadata
.
context_lens_tensor
is
not
None
:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states
=
causal_conv1d_fn
(
hidden_states
,
conv_weights
,
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
conv_states
=
mamba_cache_params
.
conv_state
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
hidden_states
=
causal_conv1d_update
(
hidden_states
.
transpose
(
0
,
1
),
mamba_cache_params
.
conv_state
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
activation
,
conv_state_indices
=
mamba_cache_params
.
state_indices_tensor
)
hidden_states
=
hidden_states
.
transpose
(
0
,
1
)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters
=
self
.
x_proj
(
hidden_states
.
transpose
(
-
2
,
-
1
))[
0
]
time_step
,
B
,
C
=
torch
.
split
(
ssm_parameters
,
[
self
.
time_step_rank
,
self
.
ssm_state_size
,
self
.
ssm_state_size
],
dim
=-
1
,
)
if
self
.
use_rms_norm
:
assert
self
.
dt_layernorm
is
not
None
assert
self
.
b_layernorm
is
not
None
assert
self
.
c_layernorm
is
not
None
time_step
=
self
.
dt_layernorm
(
time_step
.
contiguous
())
B
=
self
.
b_layernorm
(
B
.
contiguous
())
C
=
self
.
c_layernorm
(
C
.
contiguous
())
discrete_time_step
=
self
.
dt_proj
(
time_step
)[
0
].
transpose
(
-
2
,
-
1
)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias
=
(
self
.
dt_proj
.
bias
.
float
()
if
hasattr
(
self
.
dt_proj
,
"bias"
)
else
None
)
if
attn_metadata
.
query_start_loc
is
not
None
\
and
attn_metadata
.
context_lens_tensor
is
not
None
:
scan_outputs
=
selective_scan_fn
(
hidden_states
,
mamba_cache_params
.
ssm_state
,
discrete_time_step
,
self
.
A
,
B
.
transpose
(
-
2
,
-
1
),
C
.
transpose
(
-
2
,
-
1
),
self
.
D
.
float
(),
gate
,
time_proj_bias
,
delta_softplus
=
True
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
scan_outputs
=
selective_state_update
(
mamba_cache_params
.
ssm_state
,
hidden_states
.
transpose
(
0
,
1
),
discrete_time_step
.
transpose
(
0
,
1
),
self
.
A
,
B
,
C
,
self
.
D
,
gate
.
transpose
(
0
,
1
),
time_proj_bias
,
dt_softplus
=
True
,
state_batch_indices
=
mamba_cache_params
.
state_indices_tensor
)
scan_outputs
=
scan_outputs
.
transpose
(
0
,
1
)
# 4. Final linear projection
contextualized_states
=
self
.
out_proj
(
scan_outputs
.
transpose
(
-
2
,
-
1
))[
0
]
return
contextualized_states
vllm/model_executor/models/jamba.py
View file @
603a661a
...
@@ -12,26 +12,19 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
...
@@ -12,26 +12,19 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.mamba_mixer
import
MambaMixer
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
_get_graph_batch_size
)
_get_graph_batch_size
)
...
@@ -41,179 +34,6 @@ from .interfaces import HasInnerState, SupportsLoRA
...
@@ -41,179 +34,6 @@ from .interfaces import HasInnerState, SupportsLoRA
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class
JambaMambaMixer
(
nn
.
Module
):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
for why A isn't selective) ∆, B, C are input-dependent
(this is a key difference between Mamba and the linear time
invariant S4, and is why Mamba is called
**selective** state spaces)
"""
def
__init__
(
self
,
config
:
JambaConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
self
.
ssm_state_size
=
config
.
mamba_d_state
self
.
conv_kernel_size
=
config
.
mamba_d_conv
self
.
intermediate_size
=
config
.
mamba_expand
*
config
.
hidden_size
self
.
time_step_rank
=
config
.
mamba_dt_rank
self
.
use_conv_bias
=
config
.
mamba_conv_bias
self
.
use_bias
=
config
.
mamba_proj_bias
self
.
conv1d
=
ColumnParallelLinear
(
input_size
=
self
.
conv_kernel_size
,
output_size
=
self
.
intermediate_size
,
bias
=
self
.
use_conv_bias
,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow to override it
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
in_proj
=
MergedColumnParallelLinear
(
self
.
hidden_size
,
[
self
.
intermediate_size
]
*
2
,
bias
=
self
.
use_bias
)
# selective projection used to make dt, B and C input dependent
self
.
x_proj
=
RowParallelLinear
(
self
.
intermediate_size
,
self
.
time_step_rank
+
self
.
ssm_state_size
*
2
,
bias
=
False
,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self
.
dt_proj
=
ColumnParallelLinear
(
self
.
time_step_rank
,
self
.
intermediate_size
,
bias
=
True
,
skip_bias_add
=
True
)
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
A
=
nn
.
Parameter
(
torch
.
empty
(
self
.
intermediate_size
//
tp_size
,
self
.
ssm_state_size
,
dtype
=
torch
.
float32
,
))
self
.
D
=
nn
.
Parameter
(
torch
.
ones
(
self
.
intermediate_size
//
tp_size
))
set_weight_attrs
(
self
.
D
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
a_weight_loader
=
composed_weight_loader
(
sharded_weight_loader
(
0
),
lambda
x
:
-
torch
.
exp
(
x
.
float
()))
set_weight_attrs
(
self
.
A
,
{
"weight_loader"
:
a_weight_loader
})
self
.
out_proj
=
RowParallelLinear
(
self
.
intermediate_size
,
self
.
hidden_size
,
bias
=
self
.
use_bias
,
input_is_parallel
=
True
,
)
self
.
activation
=
config
.
hidden_act
self
.
dt_layernorm
=
RMSNorm
(
self
.
time_step_rank
,
eps
=
config
.
rms_norm_eps
)
self
.
b_layernorm
=
RMSNorm
(
self
.
ssm_state_size
,
eps
=
config
.
rms_norm_eps
)
self
.
c_layernorm
=
RMSNorm
(
self
.
ssm_state_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
):
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
hidden_states
,
gate
=
projected_states
.
chunk
(
2
,
dim
=-
2
)
# 2. Convolution sequence transformation
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
))
if
attn_metadata
.
query_start_loc
is
not
None
\
and
attn_metadata
.
context_lens_tensor
is
not
None
:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states
=
causal_conv1d_fn
(
hidden_states
,
conv_weights
,
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
conv_states
=
mamba_cache_params
.
conv_state
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
hidden_states
=
causal_conv1d_update
(
hidden_states
.
transpose
(
0
,
1
),
mamba_cache_params
.
conv_state
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
activation
,
conv_state_indices
=
mamba_cache_params
.
state_indices_tensor
)
hidden_states
=
hidden_states
.
transpose
(
0
,
1
)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters
=
self
.
x_proj
(
hidden_states
.
transpose
(
-
2
,
-
1
))[
0
]
time_step
,
B
,
C
=
torch
.
split
(
ssm_parameters
,
[
self
.
time_step_rank
,
self
.
ssm_state_size
,
self
.
ssm_state_size
],
dim
=-
1
,
)
time_step
=
self
.
dt_layernorm
(
time_step
.
contiguous
())
B
=
self
.
b_layernorm
(
B
.
contiguous
())
C
=
self
.
c_layernorm
(
C
.
contiguous
())
discrete_time_step
=
self
.
dt_proj
(
time_step
)[
0
].
transpose
(
-
2
,
-
1
)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias
=
(
self
.
dt_proj
.
bias
.
float
()
if
hasattr
(
self
.
dt_proj
,
"bias"
)
else
None
)
if
attn_metadata
.
query_start_loc
is
not
None
\
and
attn_metadata
.
context_lens_tensor
is
not
None
:
scan_outputs
=
selective_scan_fn
(
hidden_states
,
mamba_cache_params
.
ssm_state
,
discrete_time_step
,
self
.
A
,
B
.
transpose
(
-
2
,
-
1
),
C
.
transpose
(
-
2
,
-
1
),
self
.
D
.
float
(),
gate
,
time_proj_bias
,
delta_softplus
=
True
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
scan_outputs
=
selective_state_update
(
mamba_cache_params
.
ssm_state
,
hidden_states
.
transpose
(
0
,
1
),
discrete_time_step
.
transpose
(
0
,
1
),
self
.
A
,
B
,
C
,
self
.
D
,
gate
.
transpose
(
0
,
1
),
time_proj_bias
,
dt_softplus
=
True
,
state_batch_indices
=
mamba_cache_params
.
state_indices_tensor
)
scan_outputs
=
scan_outputs
.
transpose
(
0
,
1
)
# 4. Final linear projection
contextualized_states
=
self
.
out_proj
(
scan_outputs
.
transpose
(
-
2
,
-
1
))[
0
]
return
contextualized_states
class
JambaMoE
(
nn
.
Module
):
class
JambaMoE
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -284,9 +104,18 @@ class JambaMambaDecoderLayer(nn.Module):
...
@@ -284,9 +104,18 @@ class JambaMambaDecoderLayer(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
config
=
config
self
.
config
=
config
self
.
mamba
=
JambaMambaMixer
(
config
)
self
.
mamba
=
MambaMixer
(
hidden_size
=
config
.
hidden_size
,
ssm_state_size
=
config
.
mamba_d_state
,
conv_kernel_size
=
config
.
mamba_d_conv
,
intermediate_size
=
config
.
mamba_expand
*
\
config
.
hidden_size
,
time_step_rank
=
config
.
mamba_dt_rank
,
use_conv_bias
=
config
.
mamba_conv_bias
,
use_bias
=
config
.
mamba_proj_bias
,
use_rms_norm
=
True
,
rms_norm_eps
=
config
.
rms_norm_eps
,
activation
=
config
.
hidden_act
)
num_experts
=
config
.
layers_num_experts
[
layer_idx
]
num_experts
=
config
.
layers_num_experts
[
layer_idx
]
ffn_layer_class
=
JambaMoE
if
num_experts
>
1
else
JambaMLP
ffn_layer_class
=
JambaMoE
if
num_experts
>
1
else
JambaMLP
...
...
vllm/model_executor/models/mamba.py
View file @
603a661a
...
@@ -10,27 +10,19 @@ from vllm.attention.backends.abstract import AttentionMetadata
...
@@ -10,27 +10,19 @@ from vllm.attention.backends.abstract import AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.mamba_mixer
import
MambaMixer
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
)
IsAttentionFree
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
_get_graph_batch_size
)
_get_graph_batch_size
)
...
@@ -38,194 +30,27 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
...
@@ -38,194 +30,27 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class
MambaMixer
(
nn
.
Module
):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
for why A isn't selective) ∆, B, C are input-dependent
(this is a key difference between Mamba and the linear time
invariant S4, and is why Mamba is called
**selective** state spaces)
"""
def
__init__
(
self
,
config
:
MambaConfig
,
layer_idx
):
super
().
__init__
()
self
.
config
=
config
self
.
layer_idx
=
layer_idx
self
.
hidden_size
=
config
.
hidden_size
self
.
ssm_state_size
=
config
.
state_size
self
.
conv_kernel_size
=
config
.
conv_kernel
self
.
intermediate_size
=
config
.
intermediate_size
self
.
time_step_rank
=
int
(
config
.
time_step_rank
)
self
.
is_falcon_mamba
=
config
.
model_type
==
"falcon_mamba"
self
.
conv1d
=
ColumnParallelLinear
(
input_size
=
self
.
conv_kernel_size
,
output_size
=
self
.
intermediate_size
,
bias
=
config
.
use_conv_bias
,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow to override it
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
in_proj
=
MergedColumnParallelLinear
(
self
.
hidden_size
,
[
self
.
intermediate_size
]
*
2
,
bias
=
config
.
use_bias
)
# selective projection used to make dt, B and C input dependent
self
.
x_proj
=
RowParallelLinear
(
self
.
intermediate_size
,
self
.
time_step_rank
+
self
.
ssm_state_size
*
2
,
bias
=
False
,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self
.
dt_proj
=
ColumnParallelLinear
(
self
.
time_step_rank
,
self
.
intermediate_size
,
bias
=
True
,
skip_bias_add
=
True
)
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
A
=
nn
.
Parameter
(
torch
.
empty
(
self
.
intermediate_size
//
tp_size
,
self
.
ssm_state_size
,
dtype
=
torch
.
float32
,
))
self
.
D
=
nn
.
Parameter
(
torch
.
ones
(
self
.
intermediate_size
//
tp_size
))
set_weight_attrs
(
self
.
D
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
a_weight_loader
=
composed_weight_loader
(
sharded_weight_loader
(
0
),
lambda
x
:
-
torch
.
exp
(
x
.
float
()))
set_weight_attrs
(
self
.
A
,
{
"weight_loader"
:
a_weight_loader
})
self
.
out_proj
=
RowParallelLinear
(
self
.
intermediate_size
,
self
.
hidden_size
,
bias
=
config
.
use_bias
,
input_is_parallel
=
True
,
)
self
.
activation
=
config
.
hidden_act
if
self
.
is_falcon_mamba
:
self
.
dt_layernorm
=
RMSNorm
(
self
.
time_step_rank
,
eps
=
config
.
mixer_rms_eps
)
self
.
b_layernorm
=
RMSNorm
(
self
.
ssm_state_size
,
eps
=
config
.
mixer_rms_eps
)
self
.
c_layernorm
=
RMSNorm
(
self
.
ssm_state_size
,
eps
=
config
.
mixer_rms_eps
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
):
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
hidden_states
,
gate
=
projected_states
.
chunk
(
2
,
dim
=-
2
)
# 2. Convolution sequence transformation
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
))
if
attn_metadata
.
query_start_loc
is
not
None
\
and
attn_metadata
.
context_lens_tensor
is
not
None
:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states
=
causal_conv1d_fn
(
hidden_states
,
conv_weights
,
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
conv_states
=
mamba_cache_params
.
conv_state
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
hidden_states
=
causal_conv1d_update
(
hidden_states
.
transpose
(
0
,
1
),
mamba_cache_params
.
conv_state
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
activation
,
conv_state_indices
=
mamba_cache_params
.
state_indices_tensor
)
hidden_states
=
hidden_states
.
transpose
(
0
,
1
)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters
=
self
.
x_proj
(
hidden_states
.
transpose
(
-
2
,
-
1
))[
0
]
time_step
,
B
,
C
=
torch
.
split
(
ssm_parameters
,
[
self
.
time_step_rank
,
self
.
ssm_state_size
,
self
.
ssm_state_size
],
dim
=-
1
,
)
# Note that Jamba and FalconMamba normalizes B, C, and time_step here
# but Mamba doesn't.
if
self
.
is_falcon_mamba
:
time_step
=
self
.
dt_layernorm
(
time_step
.
contiguous
())
B
=
self
.
b_layernorm
(
B
.
contiguous
())
C
=
self
.
c_layernorm
(
C
.
contiguous
())
discrete_time_step
=
self
.
dt_proj
(
time_step
)[
0
].
transpose
(
-
2
,
-
1
)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias
=
(
self
.
dt_proj
.
bias
.
float
()
if
hasattr
(
self
.
dt_proj
,
"bias"
)
else
None
)
if
attn_metadata
.
query_start_loc
is
not
None
\
and
attn_metadata
.
context_lens_tensor
is
not
None
:
scan_outputs
=
selective_scan_fn
(
hidden_states
,
mamba_cache_params
.
ssm_state
,
discrete_time_step
,
self
.
A
,
B
.
transpose
(
-
2
,
-
1
),
C
.
transpose
(
-
2
,
-
1
),
self
.
D
.
float
(),
gate
,
time_proj_bias
,
delta_softplus
=
True
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
scan_outputs
=
selective_state_update
(
mamba_cache_params
.
ssm_state
,
hidden_states
.
transpose
(
0
,
1
),
discrete_time_step
.
transpose
(
0
,
1
),
self
.
A
,
B
,
C
,
self
.
D
,
gate
.
transpose
(
0
,
1
),
time_proj_bias
,
dt_softplus
=
True
,
state_batch_indices
=
mamba_cache_params
.
state_indices_tensor
)
scan_outputs
=
scan_outputs
.
transpose
(
0
,
1
)
# 4. Final linear projection
contextualized_states
=
self
.
out_proj
(
scan_outputs
.
transpose
(
-
2
,
-
1
))[
0
]
return
contextualized_states
class
MambaDecoderLayer
(
nn
.
Module
):
class
MambaDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
MambaConfig
,
config
:
MambaConfig
,
layer_idx
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
config
=
config
self
.
config
=
config
self
.
is_falcon_mamba
=
config
.
model_type
==
"falcon_mamba"
self
.
is_falcon_mamba
=
config
.
model_type
==
"falcon_mamba"
self
.
mixer
=
MambaMixer
(
config
,
layer_idx
)
mixer_rms_rps
=
config
.
mixer_rms_rps
if
self
.
is_falcon_mamba
else
None
self
.
mamba
=
MambaMixer
(
hidden_size
=
config
.
hidden_size
,
ssm_state_size
=
config
.
state_size
,
conv_kernel_size
=
config
.
conv_kernel
,
intermediate_size
=
config
.
intermediate_size
,
time_step_rank
=
config
.
time_step_rank
,
use_conv_bias
=
config
.
use_conv_bias
,
use_bias
=
config
.
use_bias
,
use_rms_norm
=
self
.
is_falcon_mamba
,
rms_norm_eps
=
mixer_rms_rps
,
activation
=
config
.
hidden_act
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
def
forward
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment