Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a89209b7
Unverified
Commit
a89209b7
authored
Jun 19, 2025
by
Chen Zhang
Committed by
GitHub
Jun 18, 2025
Browse files
[v1] Support mamba2 (#19327)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
ffacb222
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
583 additions
and
121 deletions
+583
-121
tests/models/language/generation/test_hybrid.py
tests/models/language/generation/test_hybrid.py
+42
-11
tests/v1/test_oracle.py
tests/v1/test_oracle.py
+1
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+6
-1
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+175
-60
vllm/model_executor/models/mamba2.py
vllm/model_executor/models/mamba2.py
+33
-21
vllm/v1/attention/backends/mamba_attn.py
vllm/v1/attention/backends/mamba_attn.py
+192
-0
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+42
-1
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+24
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+68
-26
No files found.
tests/models/language/generation/test_hybrid.py
View file @
a89209b7
...
...
@@ -17,9 +17,10 @@ SSM_MODELS = [
"state-spaces/mamba-130m-hf"
,
"tiiuae/falcon-mamba-tiny-dev"
,
# TODO: Compare to a Mamba2 model. The HF transformers implementation of
# Mamba2 is buggy for Codestral as it doesn't handle n_groups.
# Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
# doesn't compare vLLM output with HF output.
# See https://github.com/huggingface/transformers/pull/35943
#
"mistralai/Mamba-Codestral-7B-v0.1",
"mistralai/Mamba-Codestral-7B-v0.1"
,
]
HYBRID_MODELS
=
[
...
...
@@ -35,6 +36,10 @@ HYBRID_MODELS = [
"hmellor/tiny-random-BambaForCausalLM"
,
]
V1_SUPPORTED_MODELS
=
[
"mistralai/Mamba-Codestral-7B-v0.1"
,
]
# Avoid OOM
MAX_NUM_SEQS
=
4
...
...
@@ -46,23 +51,49 @@ def test_models(
hf_runner
,
vllm_runner
,
example_prompts
,
monkeypatch
,
model
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
)
->
None
:
with
hf_runner
(
model
)
as
hf_model
:
if
model
!=
"mistralai/Mamba-Codestral-7B-v0.1"
:
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
example_prompts
,
max_tokens
,
num_logprobs
)
else
:
hf_outputs
=
None
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
vllm_v0_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
if
model
in
V1_SUPPORTED_MODELS
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
,
enforce_eager
=
True
,
enable_prefix_caching
=
False
)
as
vllm_model
:
vllm_v1_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
else
:
vllm_v1_outputs
=
None
if
hf_outputs
is
not
None
:
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
outputs_1_lst
=
vllm_
v0_
outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
name_1
=
"vllm-v0"
,
)
if
model
in
V1_SUPPORTED_MODELS
:
ref_outputs
=
hf_outputs
if
hf_outputs
is
not
None
else
vllm_v0_outputs
check_logprobs_close
(
outputs_0_lst
=
ref_outputs
,
outputs_1_lst
=
vllm_v1_outputs
,
name_0
=
"hf"
if
hf_outputs
is
not
None
else
"vllm-v0"
,
name_1
=
"vllm-v1"
,
)
...
...
tests/v1/test_oracle.py
View file @
a89209b7
...
...
@@ -12,7 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1
=
[
"openai/whisper-large-v3"
,
# transcription
"facebook/bart-large-cnn"
,
# encoder decoder
"
mistralai/Mamba-Codestral-7B-v0.1
"
,
# mamba
"
state-spaces/mamba-130m-hf
"
,
# mamba
1
"hmellor/tiny-random-BambaForCausalLM"
,
# hybrid
"BAAI/bge-m3"
,
# embedding
]
...
...
vllm/engine/arg_utils.py
View file @
a89209b7
...
...
@@ -1355,12 +1355,17 @@ class EngineArgs:
recommend_to_remove
=
False
)
return
False
# No
Mamba or
Encoder-Decoder so far.
# No Encoder-Decoder
, not all Mamba
so far.
if
not
model_config
.
is_v1_compatible
:
_raise_or_fallback
(
feature_name
=
model_config
.
architectures
,
recommend_to_remove
=
False
)
return
False
# V1 mamba models are unoptimized.
if
model_config
.
has_inner_state
and
_warn_or_fallback
(
feature_name
=
"Mamba"
):
return
False
# No Concurrent Partial Prefills so far.
if
(
self
.
max_num_partial_prefills
!=
SchedulerConfig
.
max_num_partial_prefills
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
a89209b7
...
...
@@ -6,7 +6,9 @@ from typing import Optional, Union
import
torch
from
torch
import
nn
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
...
...
@@ -27,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction
,
composed_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.v1.attention.backends.mamba_attn
import
Mamba2AttentionMetadata
# Added by the IBM Team, 2024
...
...
@@ -241,6 +244,8 @@ class MambaMixer2(CustomOp):
activation
:
str
=
"silu"
,
use_rms_norm
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
chunk_size
:
int
=
-
1
,
# the chunk size used by v1
):
super
().
__init__
()
...
...
@@ -273,6 +278,7 @@ class MambaMixer2(CustomOp):
),
"Tensor parallel currently not supported for quantized models."
self
.
ssm_state_size
=
ssm_state_size
self
.
conv_kernel_size
=
conv_kernel_size
self
.
activation
=
activation
self
.
intermediate_size
=
intermediate_size
...
...
@@ -411,6 +417,22 @@ class MambaMixer2(CustomOp):
self
.
use_rms_norm
,
eps
=
rms_norm_eps
)
if
envs
.
VLLM_USE_V1
:
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self
.
kv_cache
=
[(
torch
.
tensor
([]),
torch
.
tensor
([]))]
assert
chunk_size
!=
-
1
,
"chunk_size must be set for v1"
# NOTE: chunk_size may be -1 for models without v1 support
self
.
chunk_size
=
chunk_size
self
.
prefix
=
prefix
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -426,17 +448,37 @@ class MambaMixer2(CustomOp):
mamba2_metadata
:
Mamba2Metadata
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
):
forward_context
=
get_forward_context
()
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
num_prefills
=
attn_metadata
.
num_prefills
# request count
num_decodes
=
attn_metadata
.
num_decode_tokens
# token count (=request)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
# token count
has_prefill
=
num_prefills
>
0
has_decode
=
num_decodes
>
0
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
if
envs
.
VLLM_USE_V1
:
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
conv_state
=
self_kv_cache
[
0
]
ssm_state
=
self_kv_cache
[
1
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
has_initial_states_p
=
attn_metadata
.
has_initial_states
prep_initial_states
=
attn_metadata
.
prep_initial_states
chunk_size
=
attn_metadata
.
chunk_size
seq_idx_p
=
attn_metadata
.
seq_idx
chunk_indices_p
=
attn_metadata
.
chunk_indices
chunk_offsets_p
=
attn_metadata
.
chunk_offsets
else
:
conv_state
=
mamba_cache_params
.
conv_state
ssm_state
=
mamba_cache_params
.
ssm_state
state_indices_tensor
=
mamba_cache_params
.
state_indices_tensor
has_initial_states_p
=
mamba2_metadata
.
has_initial_states
prep_initial_states
=
mamba2_metadata
.
prep_initial_states
chunk_size
=
mamba2_metadata
.
chunk_size
seq_idx_p
=
mamba2_metadata
.
seq_idx
chunk_indices_p
=
mamba2_metadata
.
chunk_indices
chunk_offsets_p
=
mamba2_metadata
.
chunk_offsets
groups_time_state_size
=
self
.
n_groups
*
self
.
ssm_state_size
...
...
@@ -459,8 +501,57 @@ class MambaMixer2(CustomOp):
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
))
# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn
=
lambda
hidden_states_B_C
:
torch
.
split
(
hidden_states_B_C
,
[
self
.
intermediate_size
//
self
.
tp_size
,
groups_time_state_size
//
self
.
tp_size
,
groups_time_state_size
//
self
.
tp_size
,
],
dim
=-
1
,
)
if
envs
.
VLLM_USE_V1
and
attn_metadata
is
None
:
# V1 profile run
hidden_states_B_C
=
(
hidden_states_B_C
.
transpose
(
0
,
1
).
clone
().
transpose
(
0
,
1
)).
contiguous
()
hidden_states
,
_B
,
_C
=
split_hidden_states_B_C_fn
(
hidden_states_B_C
)
hidden_states
=
self
.
norm
(
hidden_states
,
gate
)
out
,
_
=
self
.
out_proj
(
hidden_states
)
return
out
num_prefills
=
attn_metadata
.
num_prefills
# request count
num_decodes
=
attn_metadata
.
num_decode_tokens
# token count (=request)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
# token count
has_prefill
=
num_prefills
>
0
has_decode
=
num_decodes
>
0
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
if
envs
.
VLLM_USE_V1
:
hidden_states_B_C_d
,
hidden_states_B_C_p
=
torch
.
split
(
hidden_states_B_C
,
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
dt_d
,
dt_p
=
torch
.
split
(
dt
,
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
# Split along batch dimension
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor
,
[
num_decodes
,
num_prefills
],
dim
=
0
,
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decodes
if
has_prefill
else
None
)
else
:
hidden_states_B_C_p
,
hidden_states_B_C_d
=
torch
.
split
(
hidden_states_B_C
,
[
num_prefill_tokens
,
num_decodes
],
...
...
@@ -473,38 +564,28 @@ class MambaMixer2(CustomOp):
)
# Split along batch dimension
state_indices_tensor_p
,
state_indices_tensor_d
=
torch
.
split
(
mamba_cache_params
.
state_indices_tensor
,
state_indices_tensor
,
[
num_prefills
,
num_decodes
],
dim
=
0
,
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[:
num_prefills
+
1
]
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[:
num_prefills
+
1
]
if
has_prefill
else
None
)
# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn
=
lambda
hidden_states_B_C
:
torch
.
split
(
hidden_states_B_C
,
[
self
.
intermediate_size
//
self
.
tp_size
,
groups_time_state_size
//
self
.
tp_size
,
groups_time_state_size
//
self
.
tp_size
,
],
dim
=-
1
,
)
ssd_output_list
=
[]
# Process prefill requests
if
has_prefill
:
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "
mamba_cache_params.
state_indices_tensor"
# pointed to by "state_indices_tensor"
hidden_states_B_C_p
=
causal_conv1d_fn
(
hidden_states_B_C_p
.
transpose
(
0
,
1
),
conv_weights
,
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
conv_states
=
mamba_cache_params
.
conv_state
,
has_initial_state
=
mamba2_metadata
.
has_initial_states
,
conv_states
=
conv_state
,
has_initial_state
=
has_initial_states
_p
,
cache_indices
=
state_indices_tensor_p
,
query_start_loc
=
query_start_loc_p
).
transpose
(
0
,
1
)[:
num_prefill_tokens
]
...
...
@@ -516,12 +597,11 @@ class MambaMixer2(CustomOp):
# 3. State Space Model sequence transformation
initial_states
=
None
if
(
mamba2_metadata
.
has_initial_states
is
not
None
and
mamba2_metadata
.
prep_initial_states
):
if
(
has_initial_states_p
is
not
None
and
prep_initial_states
):
# making a copy of the states
initial_states
=
torch
.
where
(
mamba2_metadata
.
has_initial_states
[:,
None
,
None
,
None
],
mamba_cache_params
.
ssm_state
[
state_indices_tensor_p
],
0
)
has_initial_states
_p
[:,
None
,
None
,
None
],
ssm_state
[
state_indices_tensor_p
],
0
)
scan_output
,
varlen_state
=
mamba_chunk_scan_combined
(
hidden_states_p
.
view
(
1
,
num_prefill_tokens
,
...
...
@@ -533,14 +613,14 @@ class MambaMixer2(CustomOp):
-
1
),
C_p
.
view
(
1
,
num_prefill_tokens
,
self
.
n_groups
//
self
.
tp_size
,
-
1
),
chunk_size
=
mamba2_metadata
.
chunk_size
,
chunk_size
=
chunk_size
,
D
=
self
.
D
,
z
=
None
,
dt_bias
=
self
.
dt_bias
,
seq_idx
=
mamba2_metadata
.
seq_idx
,
chunk_indices
=
mamba2_metadata
.
chunk_indices
,
chunk_offsets
=
mamba2_metadata
.
chunk_offsets
,
cu_seqlens
=
attn_metadata
.
query_start_loc
[:
num_prefills
+
1
]
,
seq_idx
=
seq_idx
_p
,
chunk_indices
=
chunk_indices
_p
,
chunk_offsets
=
chunk_offsets
_p
,
cu_seqlens
=
query_start_loc
_p
,
initial_states
=
initial_states
,
return_varlen_states
=
True
,
return_final_states
=
False
,
...
...
@@ -550,7 +630,7 @@ class MambaMixer2(CustomOp):
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
mamba_cache_params
.
ssm_state
[
state_indices_tensor_p
]
=
varlen_state
ssm_state
[
state_indices_tensor_p
]
=
varlen_state
# - reshape
ssd_output_list
.
append
(
scan_output
.
view
(
num_prefill_tokens
,
-
1
))
...
...
@@ -560,7 +640,7 @@ class MambaMixer2(CustomOp):
# 2. Convolution sequence transformation
hidden_states_B_C_d
=
causal_conv1d_update
(
hidden_states_B_C_d
,
mamba_cache_params
.
conv_state
,
conv_state
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
activation
,
...
...
@@ -586,7 +666,7 @@ class MambaMixer2(CustomOp):
# using state_indices_tensor_d
hidden_states_d
=
selective_state_update
(
mamba_cache_params
.
ssm_state
,
ssm_state
,
hidden_states_d
,
dt_d
,
A_d
,
...
...
@@ -598,6 +678,13 @@ class MambaMixer2(CustomOp):
dt_softplus
=
True
,
state_batch_indices
=
state_indices_tensor_d
,
)
if
envs
.
VLLM_USE_V1
:
ssd_output_list
.
insert
(
0
,
hidden_states_d
.
view
(
-
1
,
(
self
.
num_heads
//
self
.
tp_size
)
*
self
.
head_dim
))
else
:
ssd_output_list
.
append
(
hidden_states_d
.
view
(
-
1
,
(
self
.
num_heads
//
self
.
tp_size
)
*
self
.
head_dim
))
...
...
@@ -614,3 +701,31 @@ class MambaMixer2(CustomOp):
# 5. Final linear projection
out
,
_
=
self
.
out_proj
(
hidden_states
)
return
out
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return the per-rank shapes of this layer's two cached states.

    Returns:
        A ``(conv_state_shape, temporal_state_shape)`` pair where

        * ``conv_state_shape`` is ``(conv_dim / world_size,
          conv_kernel_size - 1)``, and
        * ``temporal_state_shape`` is ``(num_heads / world_size, head_dim,
          ssm_state_size)``.

        Neither shape includes a leading batch/block dimension.
    """
    world_size = get_tensor_model_parallel_world_size()

    # If n_groups is not divisible by world_size, the shards are extended so
    # that every group needed by a head is sharded along with that head.
    n_groups = self.n_groups + extra_groups_for_head_shards(
        self.n_groups, world_size)

    # The convolution runs over the concatenated [hidden_states | B | C]
    # channels: intermediate_size for hidden_states plus
    # n_groups * ssm_state_size each for B and C. These channels are split
    # across tensor-parallel ranks.
    conv_dim = self.intermediate_size + 2 * n_groups * self.ssm_state_size
    conv_state_shape = (
        divide(conv_dim, world_size),
        self.conv_kernel_size - 1,
    )

    # The SSM state is split across ranks along the head dimension.
    # It is typically small,
    #   e.g., (n_heads, d_head, d_state) = (128, 64, 128) before sharding.
    temporal_state_shape = (
        divide(self.num_heads, world_size),
        self.head_dim,
        self.ssm_state_size,
    )
    return conv_state_shape, temporal_state_shape
vllm/model_executor/models/mamba2.py
View file @
a89209b7
...
...
@@ -8,6 +8,7 @@ import torch
from
torch
import
nn
from
transformers
import
MambaConfig
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
...
...
@@ -25,8 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
,
SupportsV0Only
)
IsAttentionFree
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -44,7 +44,8 @@ class Mamba2DecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
MambaConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
mixer
=
MambaMixer2
(
hidden_size
=
config
.
hidden_size
,
...
...
@@ -60,7 +61,9 @@ class Mamba2DecoderLayer(nn.Module):
head_dim
=
config
.
head_dim
,
rms_norm_eps
=
config
.
layer_norm_epsilon
,
activation
=
config
.
hidden_act
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mixer"
,
chunk_size
=
config
.
chunk_size
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
...
...
@@ -108,8 +111,8 @@ class Mamba2Model(nn.Module):
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
Mamba2DecoderLayer
(
config
,
quant_config
=
quant_config
),
lambda
prefix
:
Mamba2DecoderLayer
(
config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm_f
=
RMSNorm
(
config
.
hidden_size
,
...
...
@@ -142,10 +145,14 @@ class Mamba2Model(nn.Module):
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
if
not
envs
.
VLLM_USE_V1
:
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
chunk_size
,
attn_metadata
=
attn_metadata
,
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
...
...
@@ -155,7 +162,7 @@ class Mamba2Model(nn.Module):
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
self
.
start_layer
),
i
-
self
.
start_layer
)
if
mamba_cache_params
else
None
,
mamba2_metadata
=
mamba2_metadata
)
if
not
get_pp_group
().
is_last_rank
:
...
...
@@ -190,8 +197,7 @@ class Mamba2Model(nn.Module):
return
loaded_params
class
Mamba2ForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
,
SupportsV0Only
):
class
Mamba2ForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
...
...
@@ -242,14 +248,20 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
num_mamba_layers
=
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
num_mamba_layers
=
(
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
))
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
*
self
.
_get_mamba_cache_shape
())
self
.
vllm_config
,
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
*
self
.
_get_mamba_cache_shape
())
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
else
:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params
=
None
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
...
...
vllm/v1/attention/backends/mamba_attn.py
0 → 100644
View file @
a89209b7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
_query_start_loc_to_chunk_indices_offsets
)
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
MambaSpec
from
vllm.v1.worker.block_table
import
BlockTable
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
    """Return the chunk size shared by all ``MambaMixer2`` layers.

    Looks up every ``MambaMixer2`` layer registered in ``vllm_config`` and
    asserts that they all report the same ``chunk_size``.
    """
    # Imported lazily: mamba_mixer2 itself imports from this module, so a
    # top-level import here would be circular.
    from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2

    mixers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
    distinct_sizes = {mixer.chunk_size for mixer in mixers.values()}
    assert len(distinct_sizes) == 1, (
        "All Mamba2 layers must have the same chunk size")
    return next(iter(distinct_sizes))
class Mamba2AttentionBackend(AttentionBackend):
    """Backend entry for Mamba2 layers in the v1 engine.

    It only exposes the metadata builder class; no other backend hooks are
    overridden here.
    """

    @staticmethod
    def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
        """Return the builder class that produces Mamba2AttentionMetadata."""
        return Mamba2AttentionMetadataBuilder
@dataclass
class Mamba2AttentionMetadata:
    """Per-step metadata consumed by Mamba2 layers in the v1 engine.

    The batch is laid out with decode requests first and prefill requests
    last (see ``Mamba2AttentionMetadataBuilder.reorder_batch``).
    """
    # Number of prefill requests (requests with more than one scheduled token).
    num_prefills: int
    # Total number of scheduled tokens across the prefill requests.
    num_prefill_tokens: int
    # Number of decode requests (requests with exactly one scheduled token).
    num_decodes: int
    # Total number of scheduled tokens across the decode requests.
    num_decode_tokens: int
    # Taken unchanged from the CommonAttentionMetadata of the whole batch.
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor

    # One boolean per prefill request: True when the request already has
    # computed tokens. The builder leaves this as None when the batch has no
    # prefills.
    has_initial_states: torch.Tensor
    # True when at least one prefill request has an initial state.
    prep_initial_states: bool
    # Chunk size shared by all Mamba2 layers.
    chunk_size: int
    # For every prefill token, the index of the prefill request it belongs to,
    # with a leading dimension of size 1. None when the batch has no prefills.
    seq_idx: torch.Tensor
    # Only computed when prep_initial_states is True; otherwise None.
    chunk_indices: torch.Tensor
    chunk_offsets: torch.Tensor

    # First block-table entry of every request.
    state_indices_tensor: torch.Tensor  # shape: [batch,]
class Mamba2AttentionMetadataBuilder(
        AttentionMetadataBuilder[Mamba2AttentionMetadata]):
    """Builds ``Mamba2AttentionMetadata`` for one scheduling step.

    ``reorder_batch`` must run before ``build``: it caches the decode/prefill
    request and token counts on the instance, and ``build`` reads them back.
    """

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
                 block_table: BlockTable):
        self.runner = runner
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table
        self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Move decode requests to the front of the batch.

        Returns:
            True if at least one pair of requests was swapped.
        """
        # NOTE (Chen): Copied from MLACommonMetadataBuilder and
        # FlashInferMetadataBuilder. Should be refactored later to avoid code
        # duplication of these 3 functions.
        # The batch is reordered so that the "decode" requests are at the
        # front and the "prefill" requests are at the back, using as few swaps
        # as possible. (NOTE: for now "decode" loosely means requests where
        # attention is likely memory-bound and "prefill" means requests where
        # attention is likely compute-bound. TODO(lucas): figure out a better
        # naming here.)
        decode_slots = []
        prefill_slots = []
        decode_token_count = 0
        prefill_token_count = 0

        for slot, req_id in enumerate(input_batch.req_ids):
            token_count = scheduler_output.num_scheduled_tokens[req_id]
            # For now a request with exactly 1 scheduled token is treated as a
            # "decode" even if it is not one. This should eventually become
            # something like "< 8", but the decode path currently only
            # supports num_tokens = 1.
            if token_count != 1:
                prefill_slots.append(slot)
                prefill_token_count += token_count
                continue
            decode_slots.append(slot)
            decode_token_count += token_count

        # Decodes usually stay in the batch for many iterations and new
        # requests are generally appended at the back of the persistent batch,
        # so few swaps are expected. Both slot lists are in ascending order.
        # Walk the decodes from the back and the prefills from the front: any
        # decode sitting past the decode region trades places with the
        # front-most remaining prefill.
        num_decodes = len(decode_slots)
        num_prefills = len(prefill_slots)
        modified_batch = False

        for decode_idx, prefill_idx in zip(reversed(decode_slots),
                                           prefill_slots):
            # This decode (and therefore every earlier one) is already inside
            # the decode region, so nothing is left to move.
            if decode_idx < num_decodes:
                break
            input_batch.swap_states(prefill_idx, decode_idx)
            modified_batch = True

        # Save for the next `build` call.
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this.
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = decode_token_count
        self._num_prefill_tokens = prefill_token_count

        return modified_batch

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        """Assemble the metadata for the current (already reordered) batch.

        ``common_prefix_len`` is accepted for interface compatibility and is
        not used here.
        """
        num_reqs = common_attn_metadata.num_reqs
        query_start_loc = common_attn_metadata.query_start_loc
        num_prefills = self._num_prefills
        num_prefill_tokens = self._num_prefill_tokens

        # Prefill-only fields keep these defaults when the batch has no
        # prefill requests.
        seq_idx = None
        chunk_indices = None
        chunk_offsets = None
        has_initial_states = None
        prep_initial_states = False

        # Every request uses the first entry of its block-table row.
        state_indices_tensor = self.block_table.block_table[:num_reqs, 0]

        # Compute seq_idx, chunk_indices and chunk_offsets for prefill only.
        if num_prefills > 0:
            # [batch,] - prefills occupy the last `num_prefills` slots.
            computed_tokens = (
                self.runner.input_batch.num_computed_tokens_cpu_tensor)
            has_initial_states_cpu = (
                computed_tokens[num_reqs - num_prefills:num_reqs] > 0)
            prep_initial_states = torch.any(has_initial_states_cpu).item()
            has_initial_states = has_initial_states_cpu.to(
                query_start_loc.device)

            # Start offsets of the prefill requests, shifted by the number of
            # decode tokens that precede them in the batch.
            query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
                                 self._num_decode_tokens)

            request_ids = torch.arange(num_prefills,
                                       dtype=torch.int32,
                                       device=query_start_loc_p.device)
            seq_idx = torch.repeat_interleave(request_ids,
                                              query_start_loc_p.diff(),
                                              output_size=num_prefill_tokens)
            seq_idx.unsqueeze_(0)

            # Chunked-prefill metadata is computed once here and reused by all
            # mamba layers. If not needed, it is ignored inside the mamba
            # kernels.
            if prep_initial_states:
                chunk_indices, chunk_offsets = (
                    _query_start_loc_to_chunk_indices_offsets(
                        query_start_loc_p, self.chunk_size,
                        num_prefill_tokens))

        return Mamba2AttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=self._num_decodes,
            num_decode_tokens=self._num_decode_tokens,
            query_start_loc=query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            has_initial_states=has_initial_states,
            prep_initial_states=prep_initial_states,
            chunk_size=self.chunk_size,
            seq_idx=seq_idx,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
            state_indices_tensor=state_indices_tensor,
        )
vllm/v1/core/single_type_kv_cache_manager.py
View file @
a89209b7
...
...
@@ -8,7 +8,7 @@ from vllm.utils import cdiv
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheSpec
,
SlidingWindowSpec
)
MambaSpec
,
SlidingWindowSpec
)
from
vllm.v1.request
import
Request
...
...
@@ -52,6 +52,7 @@ class SingleTypeKVCacheManager(ABC):
self
.
caching_hash_fn
=
caching_hash_fn
self
.
kv_cache_group_id
=
kv_cache_group_id
self
.
_null_block
=
block_pool
.
null_block
def
get_num_blocks_to_allocate
(
self
,
request_id
:
str
,
num_tokens
:
int
,
...
...
@@ -390,9 +391,49 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
return
0
class MambaManager(SingleTypeKVCacheManager):
    """KV cache manager for Mamba groups.

    Every request holds exactly one block and prefix caching is not supported.
    """

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Return one empty block list per KV cache group (never a hit)."""
        assert isinstance(
            kv_cache_spec,
            MambaSpec), ("MambaManager can only be used for mamba groups")
        # Prefix caching is not supported for mamba now, so no group ever
        # reports cached blocks.
        no_hits: tuple[list[KVCacheBlock], ...] = tuple(
            [] for _ in kv_cache_group_ids)
        return no_hits

    def remove_skipped_blocks(self, request_id: str,
                              num_computed_tokens: int) -> None:
        """Do nothing: a request always has 1 block, so none can be removed."""
        return None

    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> int:
        """Always report zero common-prefix blocks."""
        return 0

    def allocate_new_blocks(self, request_id: str,
                            num_tokens: int) -> list[KVCacheBlock]:
        """Allocate through the base class and check the one-block invariant."""
        allocated = super().allocate_new_blocks(request_id, num_tokens)
        assert len(self.req_to_blocks[request_id]) == 1, (
            "MambaManager should only allocate 1 block for each request.")
        return allocated
spec_manager_map
:
dict
[
type
[
KVCacheSpec
],
type
[
SingleTypeKVCacheManager
]]
=
{
FullAttentionSpec
:
FullAttentionManager
,
SlidingWindowSpec
:
SlidingWindowManager
,
MambaSpec
:
MambaManager
,
}
...
...
vllm/v1/kv_cache_interface.py
View file @
a89209b7
...
...
@@ -3,6 +3,7 @@
import
copy
from
dataclasses
import
dataclass
from
math
import
prod
from
typing
import
Optional
import
torch
...
...
@@ -154,6 +155,29 @@ class SlidingWindowSpec(AttentionSpec):
return
(
cdiv
(
num_tokens
,
self
.
block_size
)
+
1
)
*
self
.
page_size_bytes
@dataclass
class MambaSpec(KVCacheSpec):
    """KV cache spec for a Mamba layer, which stores one or more state tensors."""
    # Shape of each state tensor, without a leading block dimension.
    shapes: tuple[tuple[int, ...], ...]
    # Element type shared by all state tensors.
    dtype: torch.dtype

    def __post_init__(self):
        # Total number of elements across all state tensors of one page.
        self.num_elements = sum(map(prod, self.shapes))

    @property
    def type_id(self) -> str:
        """Identifier string built from the state shapes and the dtype."""
        return f"mamba_{self.shapes}_{self.dtype}"

    @property
    def page_size_bytes(self) -> int:
        """Size in bytes of one page holding all state tensors."""
        bytes_per_element = get_dtype_size(self.dtype)
        return self.num_elements * bytes_per_element

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Return the maximum memory one request can use, in bytes."""
        # One block is allocated per request for now, so the maximum usage
        # equals the page size.
        # Need to update this when supporting prefix caching.
        return self.page_size_bytes
@
dataclass
class
KVCacheTensor
:
"""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
a89209b7
...
...
@@ -29,6 +29,7 @@ from vllm.distributed.parallel_state import (
from
vllm.forward_context
import
(
DPMetadata
,
get_forward_context
,
set_forward_context
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -38,12 +39,14 @@ from vllm.sampling_params import SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
check_use_alibi
,
get_dtype_size
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.mamba_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
KVCacheConfig
,
KVCacheSpec
,
MambaSpec
,
SlidingWindowSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
)
...
...
@@ -2093,9 +2096,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
i
,
kv_cache_group_spec
in
enumerate
(
kv_cache_config
.
kv_cache_groups
):
kv_cache_spec
=
kv_cache_group_spec
.
kv_cache_spec
if
not
isinstance
(
kv_cache_spec
,
AttentionSpec
):
raise
NotImplementedError
(
"Only AttentionSpec is supported for now."
)
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
attn_backend_i
=
get_attn_backend
(
kv_cache_spec
.
head_size
,
self
.
dtype
,
...
...
@@ -2105,8 +2106,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_mla
=
kv_cache_spec
.
use_mla
,
)
if
attn_backend_i
is
None
:
error_msg
=
(
f
"Error with get_attn_backend:
{
kv_cache_spec
.
head_size
=
}
, "
error_msg
=
(
f
"Error with get_attn_backend: "
f
"
{
kv_cache_spec
.
head_size
=
}
, "
f
"
{
self
.
dtype
=
}
,
{
kv_cache_spec
.
dtype
=
}
, "
f
"
{
kv_cache_spec
.
block_size
=
}
, "
f
"
{
self
.
model_config
.
is_attention_free
=
}
, "
...
...
@@ -2115,6 +2116,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise
NotImplementedError
(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner."
)
elif
isinstance
(
kv_cache_spec
,
MambaSpec
):
attn_backend_i
=
Mamba2AttentionBackend
else
:
raise
ValueError
(
f
"Unknown KV cache spec type:
{
type
(
kv_cache_spec
)
}
"
)
block_table_i
=
self
.
input_batch
.
block_table
[
i
]
attn_metadata_builder_i
=
attn_backend_i
.
get_builder_cls
()(
...
...
@@ -2242,6 +2248,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_caches
[
layer_name
]
=
kv_cache_raw_tensors
[
layer_name
].
view
(
dtype
).
view
(
kv_cache_shape
).
permute
(
*
inv_order
)
elif
isinstance
(
kv_cache_spec
,
MambaSpec
):
raw_tensor
=
kv_cache_raw_tensors
[
layer_name
]
dtype
=
kv_cache_spec
.
dtype
state_tensors
=
[]
start_pos
=
0
for
shape
in
kv_cache_spec
.
shapes
:
target_shape
=
(
num_blocks
,
*
shape
)
size_in_bytes
=
np
.
prod
(
shape
)
*
get_dtype_size
(
dtype
)
*
num_blocks
tensor
=
raw_tensor
[
start_pos
:
start_pos
+
size_in_bytes
]
tensor
=
tensor
.
view
(
dtype
).
view
(
target_shape
)
state_tensors
.
append
(
tensor
)
start_pos
+=
size_in_bytes
assert
start_pos
==
raw_tensor
.
numel
()
kv_caches
[
layer_name
]
=
tuple
(
state_tensors
)
else
:
raise
NotImplementedError
return
kv_caches
...
...
@@ -2307,11 +2329,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
format. Layers that do not need KV cache are not included.
"""
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
layers
.
items
():
attn_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
for
layer_name
,
attn_module
in
attn_layers
.
items
():
if
(
kv_tgt_layer
:
=
attn_module
.
kv_sharing_target_layer_name
)
is
not
None
:
# The layer doesn't need its own KV cache and will use that of
...
...
@@ -2351,4 +2373,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise
ValueError
(
f
"Unknown attention type:
{
attn_module
.
attn_type
}
"
)
mamba_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
MambaMixer2
)
if
len
(
mamba_layers
)
>
0
:
if
self
.
vllm_config
.
speculative_config
is
not
None
:
raise
NotImplementedError
(
"Mamba with speculative decoding is not supported yet."
)
if
not
self
.
vllm_config
.
model_config
.
enforce_eager
:
raise
NotImplementedError
(
"Mamba with cuda graph is not supported yet."
)
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
raise
NotImplementedError
(
"Prefix caching is not supported for Mamba yet."
)
max_model_len
=
self
.
vllm_config
.
model_config
.
max_model_len
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for
layer_name
,
mamba_module
in
mamba_layers
.
items
():
kv_cache_spec
[
layer_name
]
=
MambaSpec
(
shapes
=
mamba_module
.
get_state_shape
(),
dtype
=
self
.
kv_cache_dtype
,
block_size
=
max_model_len
)
return
kv_cache_spec
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment