Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a89209b7
Unverified
Commit
a89209b7
authored
Jun 19, 2025
by
Chen Zhang
Committed by
GitHub
Jun 18, 2025
Browse files
[v1] Support mamba2 (#19327)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
ffacb222
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
583 additions
and
121 deletions
+583
-121
tests/models/language/generation/test_hybrid.py
tests/models/language/generation/test_hybrid.py
+42
-11
tests/v1/test_oracle.py
tests/v1/test_oracle.py
+1
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+6
-1
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+175
-60
vllm/model_executor/models/mamba2.py
vllm/model_executor/models/mamba2.py
+33
-21
vllm/v1/attention/backends/mamba_attn.py
vllm/v1/attention/backends/mamba_attn.py
+192
-0
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+42
-1
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+24
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+68
-26
No files found.
tests/models/language/generation/test_hybrid.py
View file @
a89209b7
...
@@ -17,9 +17,10 @@ SSM_MODELS = [
...
@@ -17,9 +17,10 @@ SSM_MODELS = [
"state-spaces/mamba-130m-hf"
,
"state-spaces/mamba-130m-hf"
,
"tiiuae/falcon-mamba-tiny-dev"
,
"tiiuae/falcon-mamba-tiny-dev"
,
# TODO: Compare to a Mamba2 model. The HF transformers implementation of
# TODO: Compare to a Mamba2 model. The HF transformers implementation of
# Mamba2 is buggy for Codestral as it doesn't handle n_groups.
# Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
# doesn't compare vLLM output with HF output.
# See https://github.com/huggingface/transformers/pull/35943
# See https://github.com/huggingface/transformers/pull/35943
#
"mistralai/Mamba-Codestral-7B-v0.1",
"mistralai/Mamba-Codestral-7B-v0.1"
,
]
]
HYBRID_MODELS
=
[
HYBRID_MODELS
=
[
...
@@ -35,6 +36,10 @@ HYBRID_MODELS = [
...
@@ -35,6 +36,10 @@ HYBRID_MODELS = [
"hmellor/tiny-random-BambaForCausalLM"
,
"hmellor/tiny-random-BambaForCausalLM"
,
]
]
V1_SUPPORTED_MODELS
=
[
"mistralai/Mamba-Codestral-7B-v0.1"
,
]
# Avoid OOM
# Avoid OOM
MAX_NUM_SEQS
=
4
MAX_NUM_SEQS
=
4
...
@@ -46,23 +51,49 @@ def test_models(
...
@@ -46,23 +51,49 @@ def test_models(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
monkeypatch
,
model
:
str
,
model
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
)
->
None
:
)
->
None
:
with
hf_runner
(
model
)
as
hf_model
:
with
hf_runner
(
model
)
as
hf_model
:
if
model
!=
"mistralai/Mamba-Codestral-7B-v0.1"
:
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
else
:
hf_outputs
=
None
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
vllm_v0_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
if
model
in
V1_SUPPORTED_MODELS
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
,
enforce_eager
=
True
,
enable_prefix_caching
=
False
)
as
vllm_model
:
vllm_v1_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
else
:
vllm_v1_outputs
=
None
if
hf_outputs
is
not
None
:
check_logprobs_close
(
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
outputs_1_lst
=
vllm_
v0_
outputs
,
name_0
=
"hf"
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
name_1
=
"vllm-v0"
,
)
if
model
in
V1_SUPPORTED_MODELS
:
ref_outputs
=
hf_outputs
if
hf_outputs
is
not
None
else
vllm_v0_outputs
check_logprobs_close
(
outputs_0_lst
=
ref_outputs
,
outputs_1_lst
=
vllm_v1_outputs
,
name_0
=
"hf"
if
hf_outputs
is
not
None
else
"vllm-v0"
,
name_1
=
"vllm-v1"
,
)
)
...
...
tests/v1/test_oracle.py
View file @
a89209b7
...
@@ -12,7 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
...
@@ -12,7 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1
=
[
UNSUPPORTED_MODELS_V1
=
[
"openai/whisper-large-v3"
,
# transcription
"openai/whisper-large-v3"
,
# transcription
"facebook/bart-large-cnn"
,
# encoder decoder
"facebook/bart-large-cnn"
,
# encoder decoder
"
mistralai/Mamba-Codestral-7B-v0.1
"
,
# mamba
"
state-spaces/mamba-130m-hf
"
,
# mamba
1
"hmellor/tiny-random-BambaForCausalLM"
,
# hybrid
"hmellor/tiny-random-BambaForCausalLM"
,
# hybrid
"BAAI/bge-m3"
,
# embedding
"BAAI/bge-m3"
,
# embedding
]
]
...
...
vllm/engine/arg_utils.py
View file @
a89209b7
...
@@ -1355,12 +1355,17 @@ class EngineArgs:
...
@@ -1355,12 +1355,17 @@ class EngineArgs:
recommend_to_remove
=
False
)
recommend_to_remove
=
False
)
return
False
return
False
# No
Mamba or
Encoder-Decoder so far.
# No Encoder-Decoder
, not all Mamba
so far.
if
not
model_config
.
is_v1_compatible
:
if
not
model_config
.
is_v1_compatible
:
_raise_or_fallback
(
feature_name
=
model_config
.
architectures
,
_raise_or_fallback
(
feature_name
=
model_config
.
architectures
,
recommend_to_remove
=
False
)
recommend_to_remove
=
False
)
return
False
return
False
# V1 mamba models are unoptimized.
if
model_config
.
has_inner_state
and
_warn_or_fallback
(
feature_name
=
"Mamba"
):
return
False
# No Concurrent Partial Prefills so far.
# No Concurrent Partial Prefills so far.
if
(
self
.
max_num_partial_prefills
if
(
self
.
max_num_partial_prefills
!=
SchedulerConfig
.
max_num_partial_prefills
!=
SchedulerConfig
.
max_num_partial_prefills
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
a89209b7
...
@@ -6,7 +6,9 @@ from typing import Optional, Union
...
@@ -6,7 +6,9 @@ from typing import Optional, Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
...
@@ -27,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -27,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction
,
composed_weight_loader
,
sharded_weight_loader
)
LoaderFunction
,
composed_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.v1.attention.backends.mamba_attn
import
Mamba2AttentionMetadata
# Added by the IBM Team, 2024
# Added by the IBM Team, 2024
...
@@ -241,6 +244,8 @@ class MambaMixer2(CustomOp):
...
@@ -241,6 +244,8 @@ class MambaMixer2(CustomOp):
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
use_rms_norm
:
bool
=
True
,
use_rms_norm
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
chunk_size
:
int
=
-
1
,
# the chunk size used by v1
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -273,6 +278,7 @@ class MambaMixer2(CustomOp):
...
@@ -273,6 +278,7 @@ class MambaMixer2(CustomOp):
),
"Tensor parallel currently not supported for quantized models."
),
"Tensor parallel currently not supported for quantized models."
self
.
ssm_state_size
=
ssm_state_size
self
.
ssm_state_size
=
ssm_state_size
self
.
conv_kernel_size
=
conv_kernel_size
self
.
activation
=
activation
self
.
activation
=
activation
self
.
intermediate_size
=
intermediate_size
self
.
intermediate_size
=
intermediate_size
...
@@ -411,6 +417,22 @@ class MambaMixer2(CustomOp):
...
@@ -411,6 +417,22 @@ class MambaMixer2(CustomOp):
self
.
use_rms_norm
,
self
.
use_rms_norm
,
eps
=
rms_norm_eps
)
eps
=
rms_norm_eps
)
if
envs
.
VLLM_USE_V1
:
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self
.
kv_cache
=
[(
torch
.
tensor
([]),
torch
.
tensor
([]))]
assert
chunk_size
!=
-
1
,
"chunk_size must be set for v1"
# NOTE: chunk_size may be -1 for models without v1 support
self
.
chunk_size
=
chunk_size
self
.
prefix
=
prefix
def
forward_native
(
def
forward_native
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -426,17 +448,37 @@ class MambaMixer2(CustomOp):
...
@@ -426,17 +448,37 @@ class MambaMixer2(CustomOp):
mamba2_metadata
:
Mamba2Metadata
,
mamba2_metadata
:
Mamba2Metadata
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
forward_context
=
get_forward_context
()
# mamba2_metadata contains metadata necessary for the mamba2 triton
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
# stay the same and reused for all mamba layers in the same iteration
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
if
envs
.
VLLM_USE_V1
:
num_prefills
=
attn_metadata
.
num_prefills
# request count
if
attn_metadata
is
not
None
:
num_decodes
=
attn_metadata
.
num_decode_tokens
# token count (=request)
assert
isinstance
(
attn_metadata
,
dict
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
# token count
attn_metadata
=
attn_metadata
[
self
.
prefix
]
has_prefill
=
num_prefills
>
0
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
has_decode
=
num_decodes
>
0
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
conv_state
=
self_kv_cache
[
0
]
ssm_state
=
self_kv_cache
[
1
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
has_initial_states_p
=
attn_metadata
.
has_initial_states
prep_initial_states
=
attn_metadata
.
prep_initial_states
chunk_size
=
attn_metadata
.
chunk_size
seq_idx_p
=
attn_metadata
.
seq_idx
chunk_indices_p
=
attn_metadata
.
chunk_indices
chunk_offsets_p
=
attn_metadata
.
chunk_offsets
else
:
conv_state
=
mamba_cache_params
.
conv_state
ssm_state
=
mamba_cache_params
.
ssm_state
state_indices_tensor
=
mamba_cache_params
.
state_indices_tensor
has_initial_states_p
=
mamba2_metadata
.
has_initial_states
prep_initial_states
=
mamba2_metadata
.
prep_initial_states
chunk_size
=
mamba2_metadata
.
chunk_size
seq_idx_p
=
mamba2_metadata
.
seq_idx
chunk_indices_p
=
mamba2_metadata
.
chunk_indices
chunk_offsets_p
=
mamba2_metadata
.
chunk_offsets
groups_time_state_size
=
self
.
n_groups
*
self
.
ssm_state_size
groups_time_state_size
=
self
.
n_groups
*
self
.
ssm_state_size
...
@@ -459,8 +501,57 @@ class MambaMixer2(CustomOp):
...
@@ -459,8 +501,57 @@ class MambaMixer2(CustomOp):
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
))
self
.
conv1d
.
weight
.
size
(
2
))
# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn
=
lambda
hidden_states_B_C
:
torch
.
split
(
hidden_states_B_C
,
[
self
.
intermediate_size
//
self
.
tp_size
,
groups_time_state_size
//
self
.
tp_size
,
groups_time_state_size
//
self
.
tp_size
,
],
dim
=-
1
,
)
if
envs
.
VLLM_USE_V1
and
attn_metadata
is
None
:
# V1 profile run
hidden_states_B_C
=
(
hidden_states_B_C
.
transpose
(
0
,
1
).
clone
().
transpose
(
0
,
1
)).
contiguous
()
hidden_states
,
_B
,
_C
=
split_hidden_states_B_C_fn
(
hidden_states_B_C
)
hidden_states
=
self
.
norm
(
hidden_states
,
gate
)
out
,
_
=
self
.
out_proj
(
hidden_states
)
return
out
num_prefills
=
attn_metadata
.
num_prefills
# request count
num_decodes
=
attn_metadata
.
num_decode_tokens
# token count (=request)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
# token count
has_prefill
=
num_prefills
>
0
has_decode
=
num_decodes
>
0
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Separate prefill and decode by splitting varlen input
# Split along token dimension
# Split along token dimension
if
envs
.
VLLM_USE_V1
:
hidden_states_B_C_d
,
hidden_states_B_C_p
=
torch
.
split
(
hidden_states_B_C
,
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
dt_d
,
dt_p
=
torch
.
split
(
dt
,
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
# Split along batch dimension
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor
,
[
num_decodes
,
num_prefills
],
dim
=
0
,
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decodes
if
has_prefill
else
None
)
else
:
hidden_states_B_C_p
,
hidden_states_B_C_d
=
torch
.
split
(
hidden_states_B_C_p
,
hidden_states_B_C_d
=
torch
.
split
(
hidden_states_B_C
,
hidden_states_B_C
,
[
num_prefill_tokens
,
num_decodes
],
[
num_prefill_tokens
,
num_decodes
],
...
@@ -473,38 +564,28 @@ class MambaMixer2(CustomOp):
...
@@ -473,38 +564,28 @@ class MambaMixer2(CustomOp):
)
)
# Split along batch dimension
# Split along batch dimension
state_indices_tensor_p
,
state_indices_tensor_d
=
torch
.
split
(
state_indices_tensor_p
,
state_indices_tensor_d
=
torch
.
split
(
mamba_cache_params
.
state_indices_tensor
,
state_indices_tensor
,
[
num_prefills
,
num_decodes
],
[
num_prefills
,
num_decodes
],
dim
=
0
,
dim
=
0
,
)
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[:
num_prefills
+
1
]
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[:
num_prefills
+
1
]
if
has_prefill
else
None
)
if
has_prefill
else
None
)
# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn
=
lambda
hidden_states_B_C
:
torch
.
split
(
hidden_states_B_C
,
[
self
.
intermediate_size
//
self
.
tp_size
,
groups_time_state_size
//
self
.
tp_size
,
groups_time_state_size
//
self
.
tp_size
,
],
dim
=-
1
,
)
ssd_output_list
=
[]
ssd_output_list
=
[]
# Process prefill requests
# Process prefill requests
if
has_prefill
:
if
has_prefill
:
# 2. Convolution sequence transformation
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "
mamba_cache_params.
state_indices_tensor"
# pointed to by "state_indices_tensor"
hidden_states_B_C_p
=
causal_conv1d_fn
(
hidden_states_B_C_p
=
causal_conv1d_fn
(
hidden_states_B_C_p
.
transpose
(
0
,
1
),
hidden_states_B_C_p
.
transpose
(
0
,
1
),
conv_weights
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
activation
=
self
.
activation
,
conv_states
=
mamba_cache_params
.
conv_state
,
conv_states
=
conv_state
,
has_initial_state
=
mamba2_metadata
.
has_initial_states
,
has_initial_state
=
has_initial_states
_p
,
cache_indices
=
state_indices_tensor_p
,
cache_indices
=
state_indices_tensor_p
,
query_start_loc
=
query_start_loc_p
).
transpose
(
query_start_loc
=
query_start_loc_p
).
transpose
(
0
,
1
)[:
num_prefill_tokens
]
0
,
1
)[:
num_prefill_tokens
]
...
@@ -516,12 +597,11 @@ class MambaMixer2(CustomOp):
...
@@ -516,12 +597,11 @@ class MambaMixer2(CustomOp):
# 3. State Space Model sequence transformation
# 3. State Space Model sequence transformation
initial_states
=
None
initial_states
=
None
if
(
mamba2_metadata
.
has_initial_states
is
not
None
if
(
has_initial_states_p
is
not
None
and
prep_initial_states
):
and
mamba2_metadata
.
prep_initial_states
):
# making a copy of the states
# making a copy of the states
initial_states
=
torch
.
where
(
initial_states
=
torch
.
where
(
mamba2_metadata
.
has_initial_states
[:,
None
,
None
,
None
],
has_initial_states
_p
[:,
None
,
None
,
None
],
mamba_cache_params
.
ssm_state
[
state_indices_tensor_p
],
0
)
ssm_state
[
state_indices_tensor_p
],
0
)
scan_output
,
varlen_state
=
mamba_chunk_scan_combined
(
scan_output
,
varlen_state
=
mamba_chunk_scan_combined
(
hidden_states_p
.
view
(
1
,
num_prefill_tokens
,
hidden_states_p
.
view
(
1
,
num_prefill_tokens
,
...
@@ -533,14 +613,14 @@ class MambaMixer2(CustomOp):
...
@@ -533,14 +613,14 @@ class MambaMixer2(CustomOp):
-
1
),
-
1
),
C_p
.
view
(
1
,
num_prefill_tokens
,
self
.
n_groups
//
self
.
tp_size
,
C_p
.
view
(
1
,
num_prefill_tokens
,
self
.
n_groups
//
self
.
tp_size
,
-
1
),
-
1
),
chunk_size
=
mamba2_metadata
.
chunk_size
,
chunk_size
=
chunk_size
,
D
=
self
.
D
,
D
=
self
.
D
,
z
=
None
,
z
=
None
,
dt_bias
=
self
.
dt_bias
,
dt_bias
=
self
.
dt_bias
,
seq_idx
=
mamba2_metadata
.
seq_idx
,
seq_idx
=
seq_idx
_p
,
chunk_indices
=
mamba2_metadata
.
chunk_indices
,
chunk_indices
=
chunk_indices
_p
,
chunk_offsets
=
mamba2_metadata
.
chunk_offsets
,
chunk_offsets
=
chunk_offsets
_p
,
cu_seqlens
=
attn_metadata
.
query_start_loc
[:
num_prefills
+
1
]
,
cu_seqlens
=
query_start_loc
_p
,
initial_states
=
initial_states
,
initial_states
=
initial_states
,
return_varlen_states
=
True
,
return_varlen_states
=
True
,
return_final_states
=
False
,
return_final_states
=
False
,
...
@@ -550,7 +630,7 @@ class MambaMixer2(CustomOp):
...
@@ -550,7 +630,7 @@ class MambaMixer2(CustomOp):
# update ssm states
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
mamba_cache_params
.
ssm_state
[
state_indices_tensor_p
]
=
varlen_state
ssm_state
[
state_indices_tensor_p
]
=
varlen_state
# - reshape
# - reshape
ssd_output_list
.
append
(
scan_output
.
view
(
num_prefill_tokens
,
-
1
))
ssd_output_list
.
append
(
scan_output
.
view
(
num_prefill_tokens
,
-
1
))
...
@@ -560,7 +640,7 @@ class MambaMixer2(CustomOp):
...
@@ -560,7 +640,7 @@ class MambaMixer2(CustomOp):
# 2. Convolution sequence transformation
# 2. Convolution sequence transformation
hidden_states_B_C_d
=
causal_conv1d_update
(
hidden_states_B_C_d
=
causal_conv1d_update
(
hidden_states_B_C_d
,
hidden_states_B_C_d
,
mamba_cache_params
.
conv_state
,
conv_state
,
conv_weights
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
conv1d
.
bias
,
self
.
activation
,
self
.
activation
,
...
@@ -586,7 +666,7 @@ class MambaMixer2(CustomOp):
...
@@ -586,7 +666,7 @@ class MambaMixer2(CustomOp):
# using state_indices_tensor_d
# using state_indices_tensor_d
hidden_states_d
=
selective_state_update
(
hidden_states_d
=
selective_state_update
(
mamba_cache_params
.
ssm_state
,
ssm_state
,
hidden_states_d
,
hidden_states_d
,
dt_d
,
dt_d
,
A_d
,
A_d
,
...
@@ -598,6 +678,13 @@ class MambaMixer2(CustomOp):
...
@@ -598,6 +678,13 @@ class MambaMixer2(CustomOp):
dt_softplus
=
True
,
dt_softplus
=
True
,
state_batch_indices
=
state_indices_tensor_d
,
state_batch_indices
=
state_indices_tensor_d
,
)
)
if
envs
.
VLLM_USE_V1
:
ssd_output_list
.
insert
(
0
,
hidden_states_d
.
view
(
-
1
,
(
self
.
num_heads
//
self
.
tp_size
)
*
self
.
head_dim
))
else
:
ssd_output_list
.
append
(
ssd_output_list
.
append
(
hidden_states_d
.
view
(
-
1
,
(
self
.
num_heads
//
self
.
tp_size
)
*
hidden_states_d
.
view
(
-
1
,
(
self
.
num_heads
//
self
.
tp_size
)
*
self
.
head_dim
))
self
.
head_dim
))
...
@@ -614,3 +701,31 @@ class MambaMixer2(CustomOp):
...
@@ -614,3 +701,31 @@ class MambaMixer2(CustomOp):
# 5. Final linear projection
# 5. Final linear projection
out
,
_
=
self
.
out_proj
(
hidden_states
)
out
,
_
=
self
.
out_proj
(
hidden_states
)
return
out
return
out
def
get_state_shape
(
self
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...]]:
world_size
=
get_tensor_model_parallel_world_size
()
conv_state_shape
,
temporal_state_shape
=
None
,
None
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups
=
(
self
.
n_groups
+
extra_groups_for_head_shards
(
self
.
n_groups
,
world_size
))
# - heads and n_groups are TP-ed
conv_dim
=
(
self
.
intermediate_size
+
2
*
n_groups
*
self
.
ssm_state_size
)
conv_state_shape
=
(
divide
(
conv_dim
,
world_size
),
self
.
conv_kernel_size
-
1
,
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape
=
(
divide
(
self
.
num_heads
,
world_size
),
self
.
head_dim
,
self
.
ssm_state_size
,
)
return
conv_state_shape
,
temporal_state_shape
vllm/model_executor/models/mamba2.py
View file @
a89209b7
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
MambaConfig
from
transformers
import
MambaConfig
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
...
@@ -25,8 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -25,8 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
,
IsAttentionFree
)
SupportsV0Only
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -44,7 +44,8 @@ class Mamba2DecoderLayer(nn.Module):
...
@@ -44,7 +44,8 @@ class Mamba2DecoderLayer(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
MambaConfig
,
config
:
MambaConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
mixer
=
MambaMixer2
(
hidden_size
=
config
.
hidden_size
,
self
.
mixer
=
MambaMixer2
(
hidden_size
=
config
.
hidden_size
,
...
@@ -60,7 +61,9 @@ class Mamba2DecoderLayer(nn.Module):
...
@@ -60,7 +61,9 @@ class Mamba2DecoderLayer(nn.Module):
head_dim
=
config
.
head_dim
,
head_dim
=
config
.
head_dim
,
rms_norm_eps
=
config
.
layer_norm_epsilon
,
rms_norm_eps
=
config
.
layer_norm_epsilon
,
activation
=
config
.
hidden_act
,
activation
=
config
.
hidden_act
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mixer"
,
chunk_size
=
config
.
chunk_size
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
...
@@ -108,8 +111,8 @@ class Mamba2Model(nn.Module):
...
@@ -108,8 +111,8 @@ class Mamba2Model(nn.Module):
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
Mamba2DecoderLayer
(
config
,
lambda
prefix
:
Mamba2DecoderLayer
(
quant_config
=
quant_config
),
config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm_f
=
RMSNorm
(
config
.
hidden_size
,
self
.
norm_f
=
RMSNorm
(
config
.
hidden_size
,
...
@@ -142,10 +145,14 @@ class Mamba2Model(nn.Module):
...
@@ -142,10 +145,14 @@ class Mamba2Model(nn.Module):
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
if
not
envs
.
VLLM_USE_V1
:
mamba2_metadata
=
prepare_mamba2_metadata
(
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
chunk_size
,
chunk_size
=
self
.
config
.
chunk_size
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
)
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
for
i
in
range
(
len
(
self
.
layers
)):
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
...
@@ -155,7 +162,7 @@ class Mamba2Model(nn.Module):
...
@@ -155,7 +162,7 @@ class Mamba2Model(nn.Module):
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
residual
=
residual
,
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
self
.
start_layer
),
i
-
self
.
start_layer
)
if
mamba_cache_params
else
None
,
mamba2_metadata
=
mamba2_metadata
)
mamba2_metadata
=
mamba2_metadata
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
...
@@ -190,8 +197,7 @@ class Mamba2Model(nn.Module):
...
@@ -190,8 +197,7 @@ class Mamba2Model(nn.Module):
return
loaded_params
return
loaded_params
class
Mamba2ForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
,
class
Mamba2ForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
):
SupportsV0Only
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
@@ -242,14 +248,20 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
...
@@ -242,14 +248,20 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
if
self
.
mamba_cache
is
None
:
num_mamba_layers
=
self
.
model_config
.
get_num_layers_by_block_type
(
num_mamba_layers
=
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
))
self
.
mamba_cache
=
MambaCacheManager
(
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
self
.
vllm_config
,
self
.
lm_head
.
weight
.
dtype
,
*
self
.
_get_mamba_cache_shape
())
num_mamba_layers
,
*
self
.
_get_mamba_cache_shape
())
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
else
:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params
=
None
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
mamba_cache_params
,
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
intermediate_tensors
,
inputs_embeds
)
...
...
vllm/v1/attention/backends/mamba_attn.py
0 → 100644
View file @
a89209b7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
_query_start_loc_to_chunk_indices_offsets
)
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
MambaSpec
from
vllm.v1.worker.block_table
import
BlockTable
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
def
get_mamba2_chunk_size
(
vllm_config
:
VllmConfig
)
->
int
:
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
layers
=
get_layers_from_vllm_config
(
vllm_config
,
MambaMixer2
)
chunk_sizes
=
set
(
layer
.
chunk_size
for
layer
in
layers
.
values
())
assert
len
(
chunk_sizes
)
==
1
,
"All Mamba2 layers must have the same chunk size"
return
chunk_sizes
.
pop
()
class
Mamba2AttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_builder_cls
()
->
type
[
"Mamba2AttentionMetadataBuilder"
]:
return
Mamba2AttentionMetadataBuilder
@dataclass
class Mamba2AttentionMetadata:
    """Per-step metadata handed to the Mamba2 layers.

    The builder orders the batch so that decode requests come first and
    prefill requests last. Fields annotated `Optional` are `None` when the
    builder has nothing to compute for them (see the per-field comments).
    """

    # Request / token counts for the prefill and decode parts of the batch.
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int

    # Taken unchanged from the common attention metadata.
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor

    # One flag per prefill request: True when the request already has
    # computed tokens. None when the batch contains no prefill request.
    has_initial_states: Optional[torch.Tensor]
    # True when at least one prefill request has an initial state.
    prep_initial_states: bool
    chunk_size: int

    # Index of the owning prefill request for every prefill token, with a
    # leading dimension of size 1. None when there is no prefill request.
    seq_idx: Optional[torch.Tensor]
    # Chunked-prefill helpers; only computed when `prep_initial_states` is
    # True, otherwise None.
    chunk_indices: Optional[torch.Tensor]
    chunk_offsets: Optional[torch.Tensor]

    state_indices_tensor: torch.Tensor  # shape: [batch,]
class Mamba2AttentionMetadataBuilder(
        AttentionMetadataBuilder[Mamba2AttentionMetadata]):
    """Builds `Mamba2AttentionMetadata` for each step.

    `reorder_batch` must be called before `build`: it stores the decode and
    prefill counts on the instance, and `build` reads them back.
    """

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
                 block_table: BlockTable):
        self.runner = runner
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table
        # get_mamba2_chunk_size asserts that all Mamba2 layers agree.
        self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Move decode requests to the front and prefill requests to the back.

        Returns:
            True if at least one pair of requests was swapped in
            `input_batch`, False otherwise.
        """
        # NOTE (Chen): Copied from MLACommonMetadataBuilder and
        # FlashInferMetadataBuilder. Should be refactored later to avoid
        # code duplication of these 3 functions.
        # The batch is reordered so that the "decode" requests are at the
        # front and the "prefill" requests are at the back, using as few
        # swaps as possible. (NOTE: for now "decode" loosely means requests
        # where attention is likely memory-bound and "prefill" means
        # requests where attention is likely compute-bound.
        # TODO(lucas): figure out a better naming here.)
        decode_slots = []
        prefill_slots = []
        decode_token_total = 0
        prefill_token_total = 0

        for slot, req_id in enumerate(input_batch.req_ids):
            scheduled = scheduler_output.num_scheduled_tokens[req_id]
            # For now a request with exactly 1 scheduled token is treated
            # as "decode" even if it is not. This should become something
            # like "< 8" in the future, but the decode path currently only
            # supports num_tokens = 1.
            if scheduled == 1:
                decode_slots.append(slot)
                decode_token_total += scheduled
            else:
                prefill_slots.append(slot)
                prefill_token_total += scheduled

        # Decodes tend to stay around for many iterations and new requests
        # are generally appended to the persistent batch, so the number of
        # swaps should usually be small. Both lists are in ascending order,
        # so pairing the earliest prefill with the latest decode walks the
        # prefills from the front and the decodes from the back.
        decode_count = len(decode_slots)
        prefill_count = len(prefill_slots)

        swapped = False
        for prefill_slot, decode_slot in zip(prefill_slots,
                                             reversed(decode_slots)):
            # A decode that already sits inside the first `decode_count`
            # slots is in place, and so are all decodes before it.
            if decode_slot < decode_count:
                break
            input_batch.swap_states(prefill_slot, decode_slot)
            swapped = True

        # Save for the next `build` call.
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this.
        self._num_decodes = decode_count
        self._num_prefills = prefill_count
        self._num_decode_tokens = decode_token_total
        self._num_prefill_tokens = prefill_token_total

        return swapped

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        """Assemble the metadata for the current (already reordered) batch.

        `common_prefix_len` is accepted but not used by this builder.
        """
        num_reqs = common_attn_metadata.num_reqs
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        num_prefills = self._num_prefills
        num_prefill_tokens = self._num_prefill_tokens
        num_decode_tokens = self._num_decode_tokens

        # Values used when the batch has no prefill request.
        seq_idx = None
        chunk_indices = None
        chunk_offsets = None
        has_initial_states = None
        prep_initial_states = False

        # First block-table column for each request in the batch.
        state_indices_tensor = self.block_table.block_table[:num_reqs, 0]

        # seq_idx, chunk_indices and chunk_offsets are prefill-only.
        if num_prefills > 0:
            computed_tokens = (
                self.runner.input_batch.num_computed_tokens_cpu_tensor)
            # [batch,] - prefills occupy the last `num_prefills` slots; a
            # request has an initial state if it already computed tokens.
            has_initial_states_cpu = (
                computed_tokens[num_reqs - num_prefills:num_reqs] > 0)
            prep_initial_states = torch.any(has_initial_states_cpu).item()
            has_initial_states = has_initial_states_cpu.to(
                query_start_loc.device)

            # Start offsets of the prefill requests, shifted so that the
            # first prefill token has offset 0.
            query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
                                 num_decode_tokens)

            prefill_ids = torch.arange(num_prefills,
                                       dtype=torch.int32,
                                       device=query_start_loc_p.device)
            seq_idx = torch.repeat_interleave(prefill_ids,
                                              query_start_loc_p.diff(),
                                              output_size=num_prefill_tokens)
            seq_idx.unsqueeze_(0)

            # Chunked-prefill metadata is computed once here, at the top
            # level, and reused by the mamba layers. If not needed it is
            # ignored inside the mamba kernels.
            if prep_initial_states:
                chunk_indices, chunk_offsets = (
                    _query_start_loc_to_chunk_indices_offsets(
                        query_start_loc_p, self.chunk_size,
                        num_prefill_tokens))

        return Mamba2AttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=self._num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            has_initial_states=has_initial_states,
            prep_initial_states=prep_initial_states,
            chunk_size=self.chunk_size,
            seq_idx=seq_idx,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
            state_indices_tensor=state_indices_tensor,
        )
vllm/v1/core/single_type_kv_cache_manager.py
View file @
a89209b7
...
@@ -8,7 +8,7 @@ from vllm.utils import cdiv
...
@@ -8,7 +8,7 @@ from vllm.utils import cdiv
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheSpec
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheSpec
,
SlidingWindowSpec
)
MambaSpec
,
SlidingWindowSpec
)
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -52,6 +52,7 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -52,6 +52,7 @@ class SingleTypeKVCacheManager(ABC):
self
.
caching_hash_fn
=
caching_hash_fn
self
.
caching_hash_fn
=
caching_hash_fn
self
.
kv_cache_group_id
=
kv_cache_group_id
self
.
kv_cache_group_id
=
kv_cache_group_id
self
.
_null_block
=
block_pool
.
null_block
def
get_num_blocks_to_allocate
(
def
get_num_blocks_to_allocate
(
self
,
request_id
:
str
,
num_tokens
:
int
,
self
,
request_id
:
str
,
num_tokens
:
int
,
...
@@ -390,9 +391,49 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
...
@@ -390,9 +391,49 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
return
0
return
0
class MambaManager(SingleTypeKVCacheManager):
    """KV cache manager for mamba groups.

    Prefix caching is not supported, and each request is expected to hold
    exactly one block.
    """

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Return one empty block list per KV cache group.

        Prefix caching is not supported for mamba now, so there is never a
        cache hit.
        """
        assert isinstance(
            kv_cache_spec,
            MambaSpec), ("MambaManager can only be used for mamba groups")
        no_hits: tuple[list[KVCacheBlock], ...] = tuple(
            [] for _ in kv_cache_group_ids)
        return no_hits

    def remove_skipped_blocks(self, request_id: str,
                              num_computed_tokens: int) -> None:
        """Do nothing: a request only ever has 1 block at this moment."""
        return None

    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> int:
        """Always report zero common-prefix blocks."""
        return 0

    def allocate_new_blocks(self, request_id: str,
                            num_tokens: int) -> list[KVCacheBlock]:
        """Allocate via the base class, then check the one-block invariant."""
        allocated = super().allocate_new_blocks(request_id, num_tokens)
        assert len(self.req_to_blocks[request_id]) == 1, (
            "MambaManager should only allocate 1 block for each request.")
        return allocated
spec_manager_map
:
dict
[
type
[
KVCacheSpec
],
type
[
SingleTypeKVCacheManager
]]
=
{
spec_manager_map
:
dict
[
type
[
KVCacheSpec
],
type
[
SingleTypeKVCacheManager
]]
=
{
FullAttentionSpec
:
FullAttentionManager
,
FullAttentionSpec
:
FullAttentionManager
,
SlidingWindowSpec
:
SlidingWindowManager
,
SlidingWindowSpec
:
SlidingWindowManager
,
MambaSpec
:
MambaManager
,
}
}
...
...
vllm/v1/kv_cache_interface.py
View file @
a89209b7
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
copy
import
copy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
math
import
prod
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
...
@@ -154,6 +155,29 @@ class SlidingWindowSpec(AttentionSpec):
...
@@ -154,6 +155,29 @@ class SlidingWindowSpec(AttentionSpec):
return
(
cdiv
(
num_tokens
,
self
.
block_size
)
+
1
)
*
self
.
page_size_bytes
return
(
cdiv
(
num_tokens
,
self
.
block_size
)
+
1
)
*
self
.
page_size_bytes
@dataclass
class MambaSpec(KVCacheSpec):
    """KV cache spec for a mamba layer.

    `shapes` holds one shape per state tensor; all of them use `dtype`.
    """
    shapes: tuple[tuple[int, ...], ...]
    dtype: torch.dtype

    def __post_init__(self):
        # Total element count over all state tensors.
        self.num_elements = sum(map(prod, self.shapes))

    @property
    def type_id(self) -> str:
        """String identifier derived from the shapes and the dtype."""
        return f"mamba_{self.shapes}_{self.dtype}"

    @property
    def page_size_bytes(self) -> int:
        """Size in bytes of all state tensors together."""
        bytes_per_element = get_dtype_size(self.dtype)
        return self.num_elements * bytes_per_element

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Return the maximum memory usage in bytes.

        One block is allocated per request for now, so this equals
        `page_size_bytes`. It needs updating once prefix caching is
        supported.
        """
        return self.page_size_bytes
@
dataclass
@
dataclass
class
KVCacheTensor
:
class
KVCacheTensor
:
"""
"""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
a89209b7
...
@@ -29,6 +29,7 @@ from vllm.distributed.parallel_state import (
...
@@ -29,6 +29,7 @@ from vllm.distributed.parallel_state import (
from
vllm.forward_context
import
(
DPMetadata
,
get_forward_context
,
from
vllm.forward_context
import
(
DPMetadata
,
get_forward_context
,
set_forward_context
)
set_forward_context
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
@@ -38,12 +39,14 @@ from vllm.sampling_params import SamplingType
...
@@ -38,12 +39,14 @@ from vllm.sampling_params import SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
check_use_alibi
,
get_dtype_size
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.mamba_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
CommonAttentionMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
KVCacheConfig
,
KVCacheSpec
,
MambaSpec
,
SlidingWindowSpec
)
SlidingWindowSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
)
ModelRunnerOutput
)
...
@@ -2093,9 +2096,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2093,9 +2096,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
i
,
kv_cache_group_spec
in
enumerate
(
for
i
,
kv_cache_group_spec
in
enumerate
(
kv_cache_config
.
kv_cache_groups
):
kv_cache_config
.
kv_cache_groups
):
kv_cache_spec
=
kv_cache_group_spec
.
kv_cache_spec
kv_cache_spec
=
kv_cache_group_spec
.
kv_cache_spec
if
not
isinstance
(
kv_cache_spec
,
AttentionSpec
):
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
raise
NotImplementedError
(
"Only AttentionSpec is supported for now."
)
attn_backend_i
=
get_attn_backend
(
attn_backend_i
=
get_attn_backend
(
kv_cache_spec
.
head_size
,
kv_cache_spec
.
head_size
,
self
.
dtype
,
self
.
dtype
,
...
@@ -2105,8 +2106,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2105,8 +2106,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_mla
=
kv_cache_spec
.
use_mla
,
use_mla
=
kv_cache_spec
.
use_mla
,
)
)
if
attn_backend_i
is
None
:
if
attn_backend_i
is
None
:
error_msg
=
(
error_msg
=
(
f
"Error with get_attn_backend: "
f
"Error with get_attn_backend:
{
kv_cache_spec
.
head_size
=
}
, "
f
"
{
kv_cache_spec
.
head_size
=
}
, "
f
"
{
self
.
dtype
=
}
,
{
kv_cache_spec
.
dtype
=
}
, "
f
"
{
self
.
dtype
=
}
,
{
kv_cache_spec
.
dtype
=
}
, "
f
"
{
kv_cache_spec
.
block_size
=
}
, "
f
"
{
kv_cache_spec
.
block_size
=
}
, "
f
"
{
self
.
model_config
.
is_attention_free
=
}
, "
f
"
{
self
.
model_config
.
is_attention_free
=
}
, "
...
@@ -2115,6 +2116,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2115,6 +2116,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Non-Attention backend is not supported by V1 "
"Non-Attention backend is not supported by V1 "
"GPUModelRunner."
)
"GPUModelRunner."
)
elif
isinstance
(
kv_cache_spec
,
MambaSpec
):
attn_backend_i
=
Mamba2AttentionBackend
else
:
raise
ValueError
(
f
"Unknown KV cache spec type:
{
type
(
kv_cache_spec
)
}
"
)
block_table_i
=
self
.
input_batch
.
block_table
[
i
]
block_table_i
=
self
.
input_batch
.
block_table
[
i
]
attn_metadata_builder_i
=
attn_backend_i
.
get_builder_cls
()(
attn_metadata_builder_i
=
attn_backend_i
.
get_builder_cls
()(
...
@@ -2242,6 +2248,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2242,6 +2248,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_caches
[
layer_name
]
=
kv_cache_raw_tensors
[
kv_caches
[
layer_name
]
=
kv_cache_raw_tensors
[
layer_name
].
view
(
dtype
).
view
(
kv_cache_shape
).
permute
(
layer_name
].
view
(
dtype
).
view
(
kv_cache_shape
).
permute
(
*
inv_order
)
*
inv_order
)
elif
isinstance
(
kv_cache_spec
,
MambaSpec
):
raw_tensor
=
kv_cache_raw_tensors
[
layer_name
]
dtype
=
kv_cache_spec
.
dtype
state_tensors
=
[]
start_pos
=
0
for
shape
in
kv_cache_spec
.
shapes
:
target_shape
=
(
num_blocks
,
*
shape
)
size_in_bytes
=
np
.
prod
(
shape
)
*
get_dtype_size
(
dtype
)
*
num_blocks
tensor
=
raw_tensor
[
start_pos
:
start_pos
+
size_in_bytes
]
tensor
=
tensor
.
view
(
dtype
).
view
(
target_shape
)
state_tensors
.
append
(
tensor
)
start_pos
+=
size_in_bytes
assert
start_pos
==
raw_tensor
.
numel
()
kv_caches
[
layer_name
]
=
tuple
(
state_tensors
)
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
return
kv_caches
return
kv_caches
...
@@ -2307,11 +2329,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2307,11 +2329,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
format. Layers that do not need KV cache are not included.
format. Layers that do not need KV cache are not included.
"""
"""
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
layers
.
items
():
attn_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
for
layer_name
,
attn_module
in
attn_layers
.
items
():
if
(
kv_tgt_layer
:
=
if
(
kv_tgt_layer
:
=
attn_module
.
kv_sharing_target_layer_name
)
is
not
None
:
attn_module
.
kv_sharing_target_layer_name
)
is
not
None
:
# The layer doesn't need its own KV cache and will use that of
# The layer doesn't need its own KV cache and will use that of
...
@@ -2351,4 +2373,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2351,4 +2373,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise
ValueError
(
raise
ValueError
(
f
"Unknown attention type:
{
attn_module
.
attn_type
}
"
)
f
"Unknown attention type:
{
attn_module
.
attn_type
}
"
)
mamba_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
MambaMixer2
)
if
len
(
mamba_layers
)
>
0
:
if
self
.
vllm_config
.
speculative_config
is
not
None
:
raise
NotImplementedError
(
"Mamba with speculative decoding is not supported yet."
)
if
not
self
.
vllm_config
.
model_config
.
enforce_eager
:
raise
NotImplementedError
(
"Mamba with cuda graph is not supported yet."
)
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
raise
NotImplementedError
(
"Prefix caching is not supported for Mamba yet."
)
max_model_len
=
self
.
vllm_config
.
model_config
.
max_model_len
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for
layer_name
,
mamba_module
in
mamba_layers
.
items
():
kv_cache_spec
[
layer_name
]
=
MambaSpec
(
shapes
=
mamba_module
.
get_state_shape
(),
dtype
=
self
.
kv_cache_dtype
,
block_size
=
max_model_len
)
return
kv_cache_spec
return
kv_cache_spec
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment