Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a903669e
Unverified
Commit
a903669e
authored
Sep 23, 2025
by
Thomas Parnell
Committed by
GitHub
Sep 23, 2025
Browse files
[V1] Remove V0 code paths for Hybrid models (#25400)
Signed-off-by:
Thomas Parnell
<
tpa@zurich.ibm.com
>
parent
2c58742d
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
206 additions
and
1156 deletions
+206
-1156
tests/models/language/generation/test_hybrid.py
tests/models/language/generation/test_hybrid.py
+19
-36
tests/models/registry.py
tests/models/registry.py
+6
-7
vllm/model_executor/layers/mamba/abstract.py
vllm/model_executor/layers/mamba/abstract.py
+1
-4
vllm/model_executor/layers/mamba/linear_attn.py
vllm/model_executor/layers/mamba/linear_attn.py
+40
-67
vllm/model_executor/layers/mamba/mamba2_metadata.py
vllm/model_executor/layers/mamba/mamba2_metadata.py
+0
-177
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+50
-99
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+56
-116
vllm/model_executor/layers/mamba/mamba_utils.py
vllm/model_executor/layers/mamba/mamba_utils.py
+2
-18
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+0
-3
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+3
-18
vllm/model_executor/models/bamba.py
vllm/model_executor/models/bamba.py
+3
-69
vllm/model_executor/models/constant_size_cache.py
vllm/model_executor/models/constant_size_cache.py
+0
-137
vllm/model_executor/models/falcon_h1.py
vllm/model_executor/models/falcon_h1.py
+1
-64
vllm/model_executor/models/granitemoehybrid.py
vllm/model_executor/models/granitemoehybrid.py
+6
-73
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+8
-45
vllm/model_executor/models/lfm2.py
vllm/model_executor/models/lfm2.py
+0
-7
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+6
-39
vllm/model_executor/models/mamba2.py
vllm/model_executor/models/mamba2.py
+5
-58
vllm/model_executor/models/mamba_cache.py
vllm/model_executor/models/mamba_cache.py
+0
-83
vllm/model_executor/models/minimax_cache.py
vllm/model_executor/models/minimax_cache.py
+0
-36
No files found.
tests/models/language/generation/test_hybrid.py
View file @
a903669e
...
...
@@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model
SSM_MODELS
=
[
"state-spaces/mamba-130m-hf"
,
"tiiuae/falcon-mamba-tiny-dev"
,
"yujiepan/mamba2-codestral-v0.1-tiny-random"
,
# mamba2-codestral in transformers is broken pending:
# https://github.com/huggingface/transformers/pull/40861
#"yujiepan/mamba2-codestral-v0.1-tiny-random",
]
HYBRID_MODELS
=
[
...
...
@@ -31,18 +33,7 @@ HYBRID_MODELS = [
"ibm-granite/granite-4.0-tiny-preview"
,
"tiiuae/Falcon-H1-0.5B-Base"
,
"LiquidAI/LFM2-1.2B"
,
]
V1_SUPPORTED_MODELS
=
[
"state-spaces/mamba-130m-hf"
,
"ai21labs/Jamba-tiny-dev"
,
"pfnet/plamo-2-1b"
,
"yujiepan/mamba2-codestral-v0.1-tiny-random"
,
"Zyphra/Zamba2-1.2B-instruct"
,
"hmellor/tiny-random-BambaForCausalLM"
,
"ibm-granite/granite-4.0-tiny-preview"
,
"tiiuae/Falcon-H1-0.5B-Base"
,
"LiquidAI/LFM2-1.2B"
,
"tiny-random/qwen3-next-moe"
,
]
FULL_CUDA_GRAPH_MODELS
=
[
...
...
@@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [
"Zyphra/Zamba2-1.2B-instruct"
,
]
V0_UNSUPPORTED_MODELS
=
[
"LiquidAI/LFM2-1.2B"
,
]
FP32_STATE_MODELS
=
[
"state-spaces/mamba-130m-hf"
,
"Zyphra/Zamba2-1.2B-instruct"
,
...
...
@@ -88,20 +75,16 @@ def test_models(
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
example_prompts
,
max_tokens
,
num_logprobs
)
if
model
in
V1_SUPPORTED_MODELS
:
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
vllm_v1_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
else
:
vllm_v1_outputs
=
None
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
if
model
in
V1_SUPPORTED_MODELS
:
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_v1_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm-v1"
,
)
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
SSM_MODELS
[
0
],
HYBRID_MODELS
[
0
]])
...
...
@@ -299,14 +282,14 @@ def test_full_cuda_graph(
example_prompts
,
max_tokens
,
num_logprobs
)
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
vllm_
v1_
outputs
=
vllm_model
.
generate_greedy_logprobs
(
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_
v1_
outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm
-v1
"
,
name_1
=
"vllm"
,
)
...
...
@@ -340,12 +323,12 @@ def test_fp32_cache_state(
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
,
**
{
cache_dtype_param
:
"float32"
})
as
vllm_model
:
vllm_
v1_
outputs
=
vllm_model
.
generate_greedy_logprobs
(
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_
v1_
outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm
-v1
"
,
name_1
=
"vllm"
,
)
tests/models/registry.py
View file @
a903669e
...
...
@@ -312,14 +312,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PersimmonForCausalLM"
:
_HfExamplesInfo
(
"adept/persimmon-8b-chat"
),
"PhiForCausalLM"
:
_HfExamplesInfo
(
"microsoft/phi-2"
),
"Phi3ForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3-mini-4k-instruct"
),
"Phi4FlashForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-4-mini-flash-reasoning"
,
# noqa: E501
trust_remote_code
=
True
,
v0_only
=
True
,
max_model_len
=
10240
),
"PhiMoEForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3.5-MoE-instruct"
,
trust_remote_code
=
True
),
"Plamo2ForCausalLM"
:
_HfExamplesInfo
(
"pfnet/plamo-2-1b"
,
trust_remote_code
=
True
),
max_transformers_version
=
"4.55.4"
,
transformers_version_reason
=
"HF model uses remote code that is not compatible with latest Transformers"
,
# noqa: E501
trust_remote_code
=
True
),
"QWenLMHeadModel"
:
_HfExamplesInfo
(
"Qwen/Qwen-7B-Chat"
,
max_transformers_version
=
"4.53"
,
transformers_version_reason
=
"HF model uses remote code that is not compatible with latest Transformers"
,
# noqa: E501
...
...
@@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen3ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-8B"
),
"Qwen3MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-30B-A3B"
),
"Qwen3NextForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-Next-80B-A3B-Instruct"
,
min_transformers_version
=
"4.56.2"
),
extras
=
{
"tiny-random"
:
"tiny-random/qwen3-next-moe"
},
# noqa: E501
min_transformers_version
=
"4.56.3"
),
"RWForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-40b"
),
"SeedOssForCausalLM"
:
_HfExamplesInfo
(
"ByteDance-Seed/Seed-OSS-36B-Instruct"
,
# noqa: E501
trust_remote_code
=
True
,
...
...
@@ -644,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code
=
True
,
speculative_model
=
"XiaomiMiMo/MiMo-7B-RL"
),
"Qwen3NextMTP"
:
_HfExamplesInfo
(
"Qwen/Qwen3-Next-80B-A3B-Instruct"
,
min_transformers_version
=
"4.56.
2
"
),
min_transformers_version
=
"4.56.
3
"
),
}
_TRANSFORMERS_BACKEND_MODELS
=
{
...
...
vllm/model_executor/layers/mamba/abstract.py
View file @
a903669e
...
...
@@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase):
# Contains the KV cache (mamba state) for the layer
# in the shape specified by `self.get_state_shape`.
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
kv_cache
:
list
[
Iterable
[
torch
.
Tensor
]]
kv_cache
:
tuple
[
torch
.
Tensor
,
...]
@
abstractmethod
def
get_state_shape
(
self
)
->
Iterable
[
tuple
[
int
,
...]]:
...
...
vllm/model_executor/layers/mamba/linear_attn.py
View file @
a903669e
...
...
@@ -15,7 +15,6 @@ import torch.nn.functional as F
from
einops
import
rearrange
from
torch
import
nn
from
vllm
import
envs
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
...
...
@@ -42,8 +41,6 @@ if TYPE_CHECKING:
import
torch
import
torch.distributed
from
vllm.model_executor.models.minimax_cache
import
MinimaxCacheParams
class
MiniMaxText01RMSNormTP
(
CustomOp
):
name
=
"MiniMaxText01RMSNormTP"
...
...
@@ -225,11 +222,10 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
self
.
tp_heads
:(
self
.
tp_rank
+
1
)
*
self
.
tp_heads
].
contiguous
()
if
envs
.
VLLM_USE_V1
:
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
@
staticmethod
def
weight_direct_load
(
param
:
torch
.
Tensor
,
...
...
@@ -268,8 +264,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
break
if
_prefill_idx
>=
len
(
state_indices_tensor
):
break
# prefills are packed at end of batch in V1
offset
=
attn_metadata
.
num_decode_tokens
if
envs
.
VLLM_USE_V1
else
0
offset
=
attn_metadata
.
num_decode_tokens
_start
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
]
_end
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
+
1
]
slot_id
=
state_indices_tensor
[
offset
+
_prefill_idx
]
...
...
@@ -291,10 +286,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
hidden_decode
=
self
.
_decode_infer
(
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
)
if
envs
.
VLLM_USE_V1
:
hidden
.
insert
(
0
,
hidden_decode
)
else
:
hidden
.
append
(
hidden_decode
)
hidden
.
insert
(
0
,
hidden_decode
)
if
not
hidden
:
return
torch
.
empty
((
0
,
q
.
size
(
-
1
)),
device
=
q
.
device
,
dtype
=
q
.
dtype
)
...
...
@@ -304,40 +296,28 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def
_decode_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
):
if
not
envs
.
VLLM_USE_V1
:
q
=
q
[
attn_metadata
.
num_prefill_tokens
:].
unsqueeze
(
2
).
contiguous
()
k
=
k
[
attn_metadata
.
num_prefill_tokens
:].
unsqueeze
(
2
).
contiguous
()
v
=
v
[
attn_metadata
.
num_prefill_tokens
:].
unsqueeze
(
2
).
contiguous
()
num_prefills
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
slot_id
=
state_indices_tensor
[
num_prefills
:]
else
:
q
=
q
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
k
=
k
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
v
=
v
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
slot_id
=
state_indices_tensor
[:
attn_metadata
.
num_decodes
]
q
=
q
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
k
=
k
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
v
=
v
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
slot_id
=
state_indices_tensor
[:
attn_metadata
.
num_decodes
]
hidden
=
linear_decode_forward_triton
(
q
,
k
,
v
,
kv_cache
,
self
.
tp_slope
,
slot_id
,
32
)
return
hidden
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
MinimaxCacheParams
)
->
None
:
if
not
envs
.
VLLM_USE_V1
:
self
.
_forward
(
hidden_states
,
output
,
positions
,
kv_caches
)
else
:
torch
.
ops
.
vllm
.
linear_attention
(
hidden_states
,
output
,
positions
,
self
.
prefix
,
)
positions
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
vllm
.
linear_attention
(
hidden_states
,
output
,
positions
,
self
.
prefix
,
)
def
_forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
Optional
[
MinimaxCacheParams
])
->
None
:
positions
:
torch
.
Tensor
)
->
None
:
forward_context
=
get_forward_context
()
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
if
envs
.
VLLM_USE_V1
and
attn_metadata
is
not
None
:
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
LinearAttentionMetadata
)
...
...
@@ -351,32 +331,26 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
qkvact
=
torch
.
nn
.
functional
.
silu
(
qkv32
)
qkvact
=
qkvact
.
view
((
qkv
.
shape
[
0
],
self
.
tp_heads
,
-
1
))
q
,
k
,
v
=
torch
.
split
(
qkvact
,
[
self
.
head_dim
]
*
3
,
dim
=-
1
)
if
envs
.
VLLM_USE_V1
:
if
attn_metadata
is
not
None
:
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
num_prefills
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
if
num_prefills
>
0
:
num_decode_tokens
=
getattr
(
attn_metadata
,
"num_decode_tokens"
,
0
)
for
prefill_idx
in
range
(
num_prefills
):
q_start
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
prefill_idx
]
q_end
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
prefill_idx
+
1
]
query_len
=
q_end
-
q_start
context_len
=
attn_metadata
.
seq_lens
[
num_decode_tokens
+
prefill_idx
]
-
query_len
if
context_len
==
0
:
block_to_clear
=
state_indices_tensor
[
num_decode_tokens
+
prefill_idx
]
kv_cache
[
block_to_clear
,
...]
=
0
else
:
assert
kv_caches
is
not
None
kv_cache
=
kv_caches
.
minimax_cache
state_indices_tensor
=
kv_caches
.
state_indices_tensor
if
attn_metadata
is
not
None
:
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
num_prefills
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
if
num_prefills
>
0
:
num_decode_tokens
=
getattr
(
attn_metadata
,
"num_decode_tokens"
,
0
)
for
prefill_idx
in
range
(
num_prefills
):
q_start
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
prefill_idx
]
q_end
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
prefill_idx
+
1
]
query_len
=
q_end
-
q_start
context_len
=
attn_metadata
.
seq_lens
[
num_decode_tokens
+
prefill_idx
]
-
query_len
if
context_len
==
0
:
block_to_clear
=
state_indices_tensor
[
num_decode_tokens
+
prefill_idx
]
kv_cache
[
block_to_clear
,
...]
=
0
decode_only
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
==
0
if
attn_metadata
is
None
:
...
...
@@ -410,8 +384,7 @@ def linear_attention(
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
_forward
(
hidden_states
=
hidden_states
,
output
=
output
,
positions
=
positions
,
kv_caches
=
None
)
positions
=
positions
)
def
linear_attention_fake
(
...
...
vllm/model_executor/layers/mamba/mamba2_metadata.py
deleted
100644 → 0
View file @
2c58742d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Union
import
numpy
as
np
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.placeholder_attn
import
(
PlaceholderAttentionMetadata
)
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadata
from
vllm.v1.attention.backends.mamba2_attn
import
(
Mamba2AttentionMetadata
,
_query_start_loc_to_chunk_indices_offsets
)
@
dataclass
class
Mamba2Metadata
:
prep_initial_states
:
bool
chunk_size
:
int
has_initial_states_p
:
torch
.
Tensor
seq_idx_p
:
torch
.
Tensor
chunk_indices_p
:
torch
.
Tensor
chunk_offsets_p
:
torch
.
Tensor
"""
With continuous batching layout of `x` in vLLM, to enable a Triton program
to handle a request in parallel, two supporting tensors are used
(batch_ptr, token_chunk_offset_ptr)
BLOCK_M = the # tokens to be handled by a Triton program
(can be customized for different hardware)
nums_dict:
tracks the data associated with a given value of BLOCK_M
BLOCK_M = #tokens handled by a Triton program
cu_seqlen: total tokens per batch
(used as flag to update other data at each new input)
batch_ptr: tracks batch-id handled by the Triton program
token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
(Triton implementation of causal_conv1d handles parallelism in 3-axes
- feature-axis
- batch-axis
- sequence-axis)
"""
nums_dict
:
Optional
[
dict
]
=
None
cu_seqlen
:
Optional
[
int
]
=
None
batch_ptr
:
Optional
[
torch
.
Tensor
]
=
None
token_chunk_offset_ptr
:
Optional
[
torch
.
Tensor
]
=
None
def
get_platform_metadata_classes
()
->
tuple
[
type
[
AttentionMetadata
],
...]:
"""Returns the appropriate metadata classes for the current platform."""
if
current_platform
.
is_rocm
():
from
vllm.v1.attention.backends.rocm_aiter_fa
import
(
AiterFlashAttentionMetadata
)
from
vllm.v1.attention.backends.triton_attn
import
(
TritonAttentionMetadata
)
return
(
AiterFlashAttentionMetadata
,
TritonAttentionMetadata
,
PlaceholderAttentionMetadata
)
if
current_platform
.
is_cuda
():
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionMetadata
)
from
vllm.v1.attention.backends.xformers
import
(
XFormersAttentionMetadata
)
return
(
FlashAttentionMetadata
,
XFormersAttentionMetadata
,
PlaceholderAttentionMetadata
)
raise
ValueError
(
f
"Unsupported platform for Mamba2:
{
current_platform
.
device_type
}
"
)
def
prepare_mamba2_metadata
(
chunk_size
:
int
,
attn_metadata
:
AttentionMetadata
,
)
->
Mamba2Metadata
:
# compute number of prefill and decode requests
# NOTE: in V0 we assume prefills are before decodes
num_prefills
=
attn_metadata
.
num_prefills
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
seq_idx_p
=
None
chunk_indices_p
,
chunk_offsets_p
=
None
,
None
# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states_p
=
None
prep_initial_states
=
False
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if
num_prefills
>
0
:
attn_metadata_instances
=
get_platform_metadata_classes
()
if
(
isinstance
(
attn_metadata
,
attn_metadata_instances
)
and
attn_metadata
.
context_lens_tensor
is
not
None
):
# precompute flag to avoid device syncs later in mamba2 layer
# forwards
# prep is only needed for mamba2 ssd prefill processing
has_initial_states_p
=
(
attn_metadata
.
context_lens_tensor
[:
num_prefills
]
>
0
)
prep_initial_states
=
torch
.
any
(
has_initial_states_p
).
item
()
query_start_loc_p
=
attn_metadata
.
query_start_loc
[:
num_prefills
+
1
]
seq_idx_p
=
torch
.
repeat_interleave
(
torch
.
arange
(
num_prefills
,
dtype
=
torch
.
int32
,
device
=
query_start_loc_p
.
device
),
query_start_loc_p
.
diff
(),
output_size
=
num_prefill_tokens
)
seq_idx_p
.
unsqueeze_
(
0
)
# We compute metadata for chunked prefill once at the top level model
# forward and reuse them in mamba layers. If not needed, they will be
# ignored inside mamba kernels.
if
prep_initial_states
:
chunk_indices_p
,
chunk_offsets_p
=
\
_query_start_loc_to_chunk_indices_offsets
(
query_start_loc_p
,
chunk_size
,
num_prefill_tokens
)
return
Mamba2Metadata
(
has_initial_states_p
=
has_initial_states_p
,
prep_initial_states
=
prep_initial_states
,
chunk_size
=
chunk_size
,
seq_idx_p
=
seq_idx_p
,
chunk_indices_p
=
chunk_indices_p
,
chunk_offsets_p
=
chunk_offsets_p
)
def
update_metadata
(
x
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
mamba2_metadata
:
Union
[
Mamba2Metadata
,
Mamba2AttentionMetadata
,
GDNAttentionMetadata
]):
"""
this is triggered upon handling a new input at the first layer
"""
dim
,
cu_seqlen
=
x
.
shape
mamba2_metadata
.
cu_seqlen
=
cu_seqlen
seqlens
=
np
.
diff
(
query_start_loc
.
to
(
'cpu'
))
nums_dict
=
{}
# type: ignore
for
BLOCK_M
in
[
8
]:
# cover all BLOCK_M values
nums
=
-
(
-
seqlens
//
BLOCK_M
)
nums_dict
[
BLOCK_M
]
=
{}
nums_dict
[
BLOCK_M
][
'nums'
]
=
nums
nums_dict
[
BLOCK_M
][
'tot'
]
=
nums
.
sum
().
item
()
mlist
=
torch
.
from_numpy
(
np
.
repeat
(
np
.
arange
(
len
(
nums
)),
nums
))
nums_dict
[
BLOCK_M
][
'mlist'
]
=
mlist
mlist_len
=
len
(
nums_dict
[
BLOCK_M
][
'mlist'
])
nums_dict
[
BLOCK_M
][
'mlist_len'
]
=
mlist_len
MAX_NUM_PROGRAMS
=
max
(
1024
,
mlist_len
)
*
2
offsetlist
=
[]
# type: ignore
for
idx
,
num
in
enumerate
(
nums
):
offsetlist
.
extend
(
range
(
num
))
offsetlist
=
torch
.
tensor
(
offsetlist
,
dtype
=
torch
.
int32
)
nums_dict
[
BLOCK_M
][
'offsetlist'
]
=
offsetlist
if
mamba2_metadata
.
batch_ptr
is
None
:
# Update default value after class definition
#mamba2_metadata.MAX_NUM_PROGRAMS *= 2
mamba2_metadata
.
batch_ptr
=
torch
.
full
((
MAX_NUM_PROGRAMS
,
),
PAD_SLOT_ID
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
mamba2_metadata
.
token_chunk_offset_ptr
=
torch
.
full
(
(
MAX_NUM_PROGRAMS
,
),
PAD_SLOT_ID
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
else
:
if
mamba2_metadata
.
batch_ptr
.
nelement
()
<
MAX_NUM_PROGRAMS
:
mamba2_metadata
.
batch_ptr
.
resize_
(
MAX_NUM_PROGRAMS
).
fill_
(
PAD_SLOT_ID
)
mamba2_metadata
.
token_chunk_offset_ptr
.
resize_
(
# type: ignore
MAX_NUM_PROGRAMS
).
fill_
(
PAD_SLOT_ID
)
mamba2_metadata
.
batch_ptr
[
0
:
mlist_len
].
copy_
(
mlist
)
mamba2_metadata
.
token_chunk_offset_ptr
[
# type: ignore
0
:
mlist_len
].
copy_
(
offsetlist
)
nums_dict
[
BLOCK_M
][
'batch_ptr'
]
=
mamba2_metadata
.
batch_ptr
nums_dict
[
BLOCK_M
][
'token_chunk_offset_ptr'
]
=
(
mamba2_metadata
.
token_chunk_offset_ptr
)
# type: ignore
mamba2_metadata
.
nums_dict
=
nums_dict
return
mamba2_metadata
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
a903669e
...
...
@@ -10,8 +10,6 @@ import torch
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
...
...
@@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp):
has_weight
=
rms_norm_has_weight
,
)
if
use_rms_norm
else
None
if
envs
.
VLLM_USE_V1
:
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self
.
kv_cache
=
[(
torch
.
tensor
([]),
torch
.
tensor
([]))]
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
# The inner tuple is (conv_state, ssm_state)
self
.
kv_cache
=
(
torch
.
tensor
([]),
torch
.
tensor
([]))
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
...
...
@@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp):
discrete_time_step
=
self
.
dt_proj
(
time_step
)[
0
].
transpose
(
-
2
,
-
1
)
return
discrete_time_step
,
B
,
C
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
Optional
[
MambaCacheParams
]
=
None
):
if
not
envs
.
VLLM_USE_V1
:
CustomOp
.
forward
(
self
,
hidden_states
,
output
,
mamba_cache_params
)
else
:
torch
.
ops
.
vllm
.
mamba_mixer
(
hidden_states
,
output
,
self
.
prefix
,
)
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
Optional
[
MambaCacheParams
]
=
None
):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
torch
.
ops
.
vllm
.
mamba_mixer
(
hidden_states
,
output
,
self
.
prefix
,
)
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
pass
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
Optional
[
MambaCacheParams
]
=
None
):
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
"""
Run the Mamba-1 SSM pipeline.
...
...
@@ -234,31 +216,18 @@ class MambaMixer(MambaBase, CustomOp):
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
envs
.
VLLM_USE_V1
:
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
mamba1_metadata
=
attn_metadata
assert
isinstance
(
mamba1_metadata
,
Mamba1AttentionMetadata
)
query_start_loc
=
mamba1_metadata
.
query_start_loc
state_indices_tensor
=
mamba1_metadata
.
state_indices_tensor
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
has_initial_states
=
mamba1_metadata
.
has_initial_states
num_padded_decodes
=
mamba1_metadata
.
num_padded_decodes
else
:
assert
isinstance
(
attn_metadata
,
AttentionMetadata
)
assert
mamba_cache_params
is
not
None
conv_state
=
mamba_cache_params
.
conv_state
ssm_state
=
mamba_cache_params
.
ssm_state
state_indices_tensor
=
mamba_cache_params
.
state_indices_tensor
query_start_loc
=
attn_metadata
.
query_start_loc
context_lens_tensor
=
attn_metadata
.
context_lens_tensor
has_initial_states
=
None
if
context_lens_tensor
is
not
None
:
has_initial_states
=
context_lens_tensor
>
0
num_padded_decodes
=
attn_metadata
.
num_decode_tokens
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
mamba1_metadata
=
attn_metadata
assert
isinstance
(
mamba1_metadata
,
Mamba1AttentionMetadata
)
query_start_loc
=
mamba1_metadata
.
query_start_loc
state_indices_tensor
=
mamba1_metadata
.
state_indices_tensor
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
has_initial_states
=
mamba1_metadata
.
has_initial_states
num_padded_decodes
=
mamba1_metadata
.
num_padded_decodes
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
...
...
@@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp):
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
))
if
envs
.
VLLM_USE_V1
and
attn_metadata
is
None
:
if
attn_metadata
is
None
:
# V1 profile run
hidden_states_BC
=
hidden_states_BC
.
contiguous
()
return
self
.
out_proj
(
hidden_states_BC
.
transpose
(
-
2
,
-
1
))[
0
]
...
...
@@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp):
out
=
scan_outputs_d
)
scan_outputs_d
=
scan_outputs_d
.
transpose
(
0
,
1
)
if
envs
.
VLLM_USE_V1
:
ssm_outputs
.
insert
(
0
,
scan_outputs_d
)
else
:
ssm_outputs
.
append
(
scan_outputs_d
)
ssm_outputs
.
insert
(
0
,
scan_outputs_d
)
scan_outputs_combined
=
ssm_outputs
[
0
]
if
len
(
ssm_outputs
)
==
1
else
torch
.
cat
(
ssm_outputs
,
dim
=-
1
)
...
...
@@ -441,40 +407,27 @@ def split_batch_to_prefill_and_decode(
num_decodes
:
int
,
num_padded_decodes
:
int
,
)
->
PrefillDecodeSplit
:
num_actual_tokens
=
num_prefill_tokens
+
num_padded_decodes
if
envs
.
VLLM_USE_V1
:
# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d
,
hidden_states_BC_p
=
torch
.
split
(
hidden_states_BC
[...,
:
num_actual_tokens
],
[
num_padded_decodes
,
num_prefill_tokens
],
dim
=-
1
)
gate_d
,
gate_p
=
torch
.
split
(
gate
[...,
:
num_actual_tokens
],
[
num_padded_decodes
,
num_prefill_tokens
],
dim
=-
1
)
# num_padded_decodes accounts for CUDA graph padding when applicable
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor
[:
num_padded_decodes
+
num_prefills
],
[
num_padded_decodes
,
num_prefills
],
dim
=
0
)
query_start_loc_p
=
(
query_start_loc
[
-
num_prefills
-
1
:]
-
num_padded_decodes
if
num_prefills
>
0
else
None
)
has_initial_states_p
=
has_initial_states
[
-
num_prefills
:]
if
(
has_initial_states
is
not
None
and
num_prefills
>
0
)
else
None
else
:
# In v0, prefill tokens come first, then decode tokens.
hidden_states_BC_p
,
hidden_states_BC_d
=
torch
.
split
(
hidden_states_BC
,
[
num_prefill_tokens
,
num_decode_tokens
],
dim
=-
1
)
gate_p
,
gate_d
=
torch
.
split
(
gate
,
[
num_prefill_tokens
,
num_decode_tokens
],
dim
=-
1
)
state_indices_tensor_p
,
state_indices_tensor_d
=
torch
.
split
(
state_indices_tensor
,
[
num_prefills
,
num_decodes
],
dim
=
0
)
query_start_loc_p
=
(
query_start_loc
[:
num_prefills
+
1
]
if
num_prefills
>
0
else
None
)
has_initial_states_p
=
has_initial_states
[:
num_prefills
]
if
(
has_initial_states
is
not
None
and
num_prefills
>
0
)
else
None
# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d
,
hidden_states_BC_p
=
torch
.
split
(
hidden_states_BC
[...,
:
num_actual_tokens
],
[
num_padded_decodes
,
num_prefill_tokens
],
dim
=-
1
)
gate_d
,
gate_p
=
torch
.
split
(
gate
[...,
:
num_actual_tokens
],
[
num_padded_decodes
,
num_prefill_tokens
],
dim
=-
1
)
# num_padded_decodes accounts for CUDA graph padding when applicable
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor
[:
num_padded_decodes
+
num_prefills
],
[
num_padded_decodes
,
num_prefills
],
dim
=
0
)
query_start_loc_p
=
(
query_start_loc
[
-
num_prefills
-
1
:]
-
num_padded_decodes
if
num_prefills
>
0
else
None
)
has_initial_states_p
=
has_initial_states
[
-
num_prefills
:]
if
(
has_initial_states
is
not
None
and
num_prefills
>
0
)
else
None
return
PrefillDecodeSplit
(
hidden_states_BC_p
=
hidden_states_BC_p
,
...
...
@@ -495,9 +448,7 @@ def mamba_mixer(
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
output
=
output
,
mamba_cache_params
=
None
)
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
output
=
output
)
def
mamba_mixer_fake
(
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
a903669e
...
...
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
import
torch
from
torch
import
nn
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
...
...
@@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
update_metadata
)
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
...
...
@@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
(
LoaderFunction
,
composed_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp):
self
.
use_rms_norm
,
eps
=
rms_norm_eps
)
if
envs
.
VLLM_USE_V1
:
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self
.
kv_cache
=
[(
torch
.
tensor
([]),
torch
.
tensor
([]))]
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
# The tuple is (conv_state, ssm_state)
self
.
kv_cache
=
(
torch
.
tensor
([]),
torch
.
tensor
([]))
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
...
...
@@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp):
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
):
pass
...
...
@@ -478,59 +468,43 @@ class MambaMixer2(MambaBase, CustomOp):
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
):
if
not
envs
.
VLLM_USE_V1
:
CustomOp
.
forward
(
self
,
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
,
mup_vector
)
else
:
torch
.
ops
.
vllm
.
mamba_mixer2
(
hidden_states
,
output
,
self
.
prefix
,
mup_vector
,
)
torch
.
ops
.
vllm
.
mamba_mixer2
(
hidden_states
,
output
,
self
.
prefix
,
mup_vector
,
)
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
):
forward_context
=
get_forward_context
()
#
mamba2
_metadata contains metadata necessary for the mamba2 triton
#
attn
_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
if
envs
.
VLLM_USE_V1
:
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
mamba2_metadata
=
attn_metadata
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
else
:
conv_state
=
mamba_cache_params
.
conv_state
ssm_state
=
mamba_cache_params
.
ssm_state
state_indices_tensor
=
mamba_cache_params
.
state_indices_tensor
# Common members between V1 metadata and V0 metadata
if
mamba2_metadata
is
not
None
:
has_initial_states_p
=
mamba2_metadata
.
has_initial_states_p
prep_initial_states
=
mamba2_metadata
.
prep_initial_states
chunk_size
=
mamba2_metadata
.
chunk_size
seq_idx_p
=
mamba2_metadata
.
seq_idx_p
chunk_indices_p
=
mamba2_metadata
.
chunk_indices_p
chunk_offsets_p
=
mamba2_metadata
.
chunk_offsets_p
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
prep_initial_states
=
attn_metadata
.
prep_initial_states
chunk_size
=
attn_metadata
.
chunk_size
seq_idx_p
=
attn_metadata
.
seq_idx_p
chunk_indices_p
=
attn_metadata
.
chunk_indices_p
chunk_offsets_p
=
attn_metadata
.
chunk_offsets_p
# 1. Gated MLP's linear projection
projected_states
,
_
=
self
.
in_proj
(
hidden_states
)
...
...
@@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp):
dim
=-
1
,
)
if
envs
.
VLLM_USE_V1
and
attn_metadata
is
None
:
#
V1
profile run
if
attn_metadata
is
None
:
# profile run
hidden_states_B_C
=
(
hidden_states_B_C
.
transpose
(
0
,
1
).
clone
().
transpose
(
0
,
1
)).
contiguous
()
hidden_states
,
_B
,
_C
=
split_hidden_states_B_C_fn
(
...
...
@@ -579,49 +553,27 @@ class MambaMixer2(MambaBase, CustomOp):
has_decode
=
num_decodes
>
0
num_actual_tokens
=
num_prefill_tokens
+
num_decodes
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
if
envs
.
VLLM_USE_V1
:
hidden_states_B_C_d
,
hidden_states_B_C_p
=
torch
.
split
(
hidden_states_B_C
[:
num_actual_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
dt_d
,
dt_p
=
torch
.
split
(
dt
[:
num_actual_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
# Split along batch dimension
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor
[:
num_actual_tokens
],
[
num_decodes
,
num_prefills
],
dim
=
0
,
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decodes
if
has_prefill
else
None
)
else
:
hidden_states_B_C_p
,
hidden_states_B_C_d
=
torch
.
split
(
hidden_states_B_C
,
[
num_prefill_tokens
,
num_decodes
],
dim
=
0
,
)
dt_p
,
dt_d
=
torch
.
split
(
dt
,
[
num_prefill_tokens
,
num_decodes
],
dim
=
0
,
)
# Split along batch dimension
state_indices_tensor_p
,
state_indices_tensor_d
=
torch
.
split
(
state_indices_tensor
,
[
num_prefills
,
num_decodes
],
dim
=
0
,
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[:
num_prefills
+
1
]
if
has_prefill
else
None
)
hidden_states_B_C_d
,
hidden_states_B_C_p
=
torch
.
split
(
hidden_states_B_C
[:
num_actual_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
dt_d
,
dt_p
=
torch
.
split
(
dt
[:
num_actual_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
# Split along batch dimension
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor
[:
num_actual_tokens
],
[
num_decodes
,
num_prefills
],
dim
=
0
,
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decodes
if
has_prefill
else
None
)
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
...
...
@@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp):
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
,
)
if
envs
.
VLLM_USE_V1
:
preallocated_ssm_out_d
,
preallocated_ssm_out_p
=
torch
.
split
(
preallocated_ssm_out
,
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
else
:
preallocated_ssm_out_p
,
preallocated_ssm_out_d
=
torch
.
split
(
preallocated_ssm_out
,
[
num_prefill_tokens
,
num_decodes
],
dim
=
0
,
)
preallocated_ssm_out_d
,
preallocated_ssm_out_p
=
torch
.
split
(
preallocated_ssm_out
,
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
# Process prefill requests
if
has_prefill
:
...
...
@@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp):
# pointed to by "state_indices_tensor"
x
=
hidden_states_B_C_p
.
transpose
(
0
,
1
)
# this is the form that causal-conv see
if
mamba2_metadata
.
cu_seqlen
is
None
:
mamba2_metadata
=
update_metadata
(
x
,
query_start_loc_p
,
mamba2_metadata
)
hidden_states_B_C_p
=
causal_conv1d_fn
(
x
,
conv_weights
,
...
...
@@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp):
conv_states
=
conv_state
,
has_initial_state
=
has_initial_states_p
,
cache_indices
=
state_indices_tensor_p
,
metadata
=
mamba2
_metadata
,
metadata
=
attn
_metadata
,
query_start_loc
=
query_start_loc_p
).
transpose
(
0
,
1
)[:
num_prefill_tokens
]
...
...
@@ -806,8 +748,6 @@ def mamba_mixer2(
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
output
=
output
,
mamba_cache_params
=
None
,
mamba2_metadata
=
None
,
mup_vector
=
mup_vector
)
...
...
vllm/model_executor/layers/mamba/mamba_utils.py
View file @
a903669e
...
...
@@ -100,7 +100,6 @@ class MambaStateShapeCalculator:
intermediate_size
:
int
,
state_size
:
int
,
conv_kernel
:
int
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
]]:
conv_state_shape
=
(
divide
(
intermediate_size
,
tp_world_size
),
conv_kernel
-
1
)
...
...
@@ -108,11 +107,7 @@ class MambaStateShapeCalculator:
temporal_state_shape
=
(
divide
(
intermediate_size
,
tp_world_size
),
state_size
)
# In V0, the conv_state shape was swapped during allocation in
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if
use_v1
:
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
return
conv_state_shape
,
temporal_state_shape
...
...
@@ -126,7 +121,6 @@ class MambaStateShapeCalculator:
head_dim
:
int
,
state_size
:
int
,
conv_kernel
:
int
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
...
...
@@ -137,8 +131,6 @@ class MambaStateShapeCalculator:
# contiguous along 'dim' axis
conv_state_shape
=
(
conv_kernel
-
1
,
divide
(
conv_dim
,
tp_world_size
))
if
not
use_v1
:
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
...
...
@@ -153,12 +145,9 @@ class MambaStateShapeCalculator:
tp_world_size
:
int
,
intermediate_size
:
int
,
conv_kernel
:
int
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
]]:
conv_dim
=
divide
(
intermediate_size
,
tp_world_size
)
conv_state_shape
=
(
conv_kernel
-
1
,
conv_dim
)
if
not
use_v1
:
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
return
(
conv_state_shape
,
)
@
classmethod
...
...
@@ -183,7 +172,6 @@ class MambaStateShapeCalculator:
head_v_dim
:
int
,
conv_kernel_size
:
int
,
num_spec
:
int
=
0
,
use_v1
:
bool
=
True
,
):
conv_dim
=
(
head_k_dim
*
num_k_heads
*
2
+
head_v_dim
*
num_v_heads
)
conv_state_shape
=
(
...
...
@@ -191,11 +179,7 @@ class MambaStateShapeCalculator:
conv_kernel_size
-
1
+
num_spec
,
)
# In V0, the conv_state shape was swapped during allocation in
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if
use_v1
:
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
temporal_state_shape
=
(
divide
(
num_v_heads
,
tp_world_size
),
head_k_dim
,
head_v_dim
)
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
a903669e
...
...
@@ -420,9 +420,7 @@ def causal_conv1d_fn(
x
=
x
.
to
(
conv_states
.
dtype
)
out
=
torch
.
empty_like
(
x
)
if
metadata
is
not
None
:
cu_seqlen
=
metadata
.
cu_seqlen
nums_dict
=
metadata
.
nums_dict
#x = metadata.x
args
=
nums_dict
batch_ptr
=
metadata
.
batch_ptr
token_chunk_offset_ptr
=
metadata
.
token_chunk_offset_ptr
...
...
@@ -926,7 +924,6 @@ def causal_conv1d_update(
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
max_query_len
:
int
=
-
1
,
pad_slot_id
:
int
=
PAD_SLOT_ID
,
metadata
=
None
,
validate_data
=
False
,
):
"""
...
...
vllm/model_executor/layers/mamba/short_conv.py
View file @
a903669e
...
...
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
import
torch
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.distributed
import
get_tensor_model_parallel_world_size
...
...
@@ -18,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
update_metadata
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
...
...
@@ -71,15 +69,11 @@ class ShortConv(MambaBase, CustomOp):
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
assert
envs
.
VLLM_USE_V1
,
(
"ShortConv layers are only supported in V1"
)
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
self
.
kv_cache
=
[(
torch
.
tensor
([]),
)]
self
.
kv_cache
=
(
torch
.
tensor
([]),
)
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
...
...
@@ -89,7 +83,6 @@ class ShortConv(MambaBase, CustomOp):
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
conv_metadata
:
ShortConvAttentionMetadata
,
):
return
...
...
@@ -97,7 +90,6 @@ class ShortConv(MambaBase, CustomOp):
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
conv_metadata
:
ShortConvAttentionMetadata
,
):
torch
.
ops
.
vllm
.
short_conv
(
hidden_states
,
...
...
@@ -109,7 +101,6 @@ class ShortConv(MambaBase, CustomOp):
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
conv_metadata
:
ShortConvAttentionMetadata
,
):
forward_context
=
get_forward_context
()
# ShortConvAttentionMetadata contains metadata necessary for the
...
...
@@ -121,7 +112,6 @@ class ShortConv(MambaBase, CustomOp):
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
conv_metadata
=
attn_metadata
assert
isinstance
(
attn_metadata
,
ShortConvAttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
...
...
@@ -181,9 +171,6 @@ class ShortConv(MambaBase, CustomOp):
if
has_prefill
:
Bx_p
=
(
B_p
*
x_p
).
transpose
(
0
,
1
)
if
conv_metadata
.
cu_seqlen
is
None
:
conv_metadata
=
update_metadata
(
Bx_p
,
query_start_loc_p
,
conv_metadata
)
Bx
=
causal_conv1d_fn
(
Bx_p
,
conv_weights
,
self
.
conv
.
bias
,
...
...
@@ -191,7 +178,7 @@ class ShortConv(MambaBase, CustomOp):
conv_states
=
conv_state
,
has_initial_state
=
has_initial_states_p
,
cache_indices
=
state_indices_tensor_p
,
metadata
=
conv
_metadata
,
metadata
=
attn
_metadata
,
query_start_loc
=
query_start_loc_p
).
transpose
(
0
,
1
)[:
num_prefill_tokens
]
...
...
@@ -248,9 +235,7 @@ def short_conv(
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
output
=
output
,
conv_metadata
=
None
)
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
output
=
output
)
def
short_conv_fake
(
...
...
vllm/model_executor/models/bamba.py
View file @
a903669e
...
...
@@ -9,21 +9,17 @@ import torch
from
torch
import
nn
from
transformers
import
BambaConfig
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
...
...
@@ -32,10 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
SupportsQuant
)
...
...
@@ -115,8 +108,6 @@ class BambaMixerDecoderLayer(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
):
if
residual
is
None
:
...
...
@@ -127,7 +118,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states
,
residual
)
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mamba
(
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
)
self
.
mamba
(
hidden_states
,
output
)
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
output
,
residual
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
...
...
@@ -315,22 +306,10 @@ class BambaModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
attn_metadata
=
get_forward_context
().
attn_metadata
if
not
envs
.
VLLM_USE_V1
:
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
mamba_chunk_size
,
attn_metadata
=
attn_metadata
,
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
...
...
@@ -343,23 +322,11 @@ class BambaModel(nn.Module):
residual
=
intermediate_tensors
[
"residual"
]
residual
=
None
num_attn
=
0
for
i
,
layer
in
enumerate
(
self
.
layers
):
if
isinstance
(
layer
,
BambaAttentionDecoderLayer
):
num_attn
+=
1
layer_mamba_cache_params
=
None
if
isinstance
(
layer
,
BambaMixerDecoderLayer
)
and
mamba_cache_params
:
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
num_attn
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
...
...
@@ -457,13 +424,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
get_mamba_state_shape_from_config
(
cls
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
...
...
@@ -482,7 +447,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim
=
hf_config
.
mamba_d_head
,
state_size
=
hf_config
.
mamba_d_state
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
use_v1
=
use_v1
,
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -515,8 +479,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
...
...
@@ -534,39 +496,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
mamba_cache_params
=
None
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
num_mamba_layers
=
\
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
mamba_state_shape
=
\
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
,
use_v1
=
False
)
mamba_state_dtype
=
\
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_mamba_layers
,
*
mamba_state_shape
,
*
mamba_state_dtype
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
return
self
.
mamba_cache
.
copy_inputs_before_cuda_graphs
(
input_buffers
,
**
kwargs
)
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
return
self
.
mamba_cache
.
get_seqlen_agnostic_capture_inputs
(
batch_size
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/constant_size_cache.py
deleted
100644 → 0
View file @
2c58742d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
import
torch
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
class ConstantSizeCache(ABC):
    """
    Abstract base class for managing constant size caches
    like Mamba and Minimax.

    The class hands out integer slot indices (``0 .. max_batch_size - 1``)
    to ``(request_id, seq_id)`` pairs and reclaims them once a request is
    reported as finished. Subclasses supply the actual cache storage via
    the ``cache`` property and the slot-to-slot copy via ``_copy_cache``.
    """

    def __init__(self, max_batch_size: int):
        # Maps between the request id and a dict that maps between the seq_id
        # and its index inside the cache
        self.cache_indices_mapping: dict[str, dict[int, int]] = {}
        # Slots not currently assigned to any sequence; slots are taken
        # from the end of this list and returned by appending to it.
        self.free_cache_indices = list(range(max_batch_size))

    @property
    @abstractmethod
    def cache(self) -> Any:
        """Return the underlying cache tensor(s)"""
        pass

    @abstractmethod
    def _copy_cache(self, from_index: int, to_index: int):
        """Copy cache data from one index to another"""
        pass

    def current_run_tensors(self, **kwargs) -> tuple:
        """
        Return the tensors for the current run's conv and ssm state.

        When ``kwargs`` has no ``"seqlen_agnostic_capture_inputs"`` entry,
        finished requests are released, a slot is resolved for every
        sequence in ``kwargs["request_ids_to_seq_ids"]`` and the slot list
        is materialised as an int32 tensor on ``"cuda"``. Otherwise the
        pre-built ``(cache, state_indices)`` pair stored under that key is
        returned unchanged.

        Returns:
            A ``(cache_tensors, state_indices_tensor)`` tuple.
        """
        if "seqlen_agnostic_capture_inputs" not in kwargs:
            # We get here only on Prefill/Eager mode runs
            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
            finished_requests_ids = kwargs["finished_requests_ids"]

            # Free slots first so they can be reused within this same step.
            self._release_finished_requests(finished_requests_ids)
            state_indices = self._prepare_current_run_cache(
                request_ids_to_seq_ids, finished_requests_ids)

            state_indices_tensor = torch.as_tensor(state_indices,
                                                   dtype=torch.int32,
                                                   device="cuda")
            cache_tensors = self.cache
        else:
            # CUDA graph capturing runs
            cache_tensors, state_indices_tensor = kwargs[
                "seqlen_agnostic_capture_inputs"]

        return (cache_tensors, state_indices_tensor)

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """
        Copy the relevant state_indices into the CUDA graph input buffer

        The slot list is right-padded with ``PAD_SLOT_ID`` up to the length
        of the state-indices buffer found in
        ``input_buffers["seqlen_agnostic_capture_inputs"]`` before being
        copied into that buffer in place.
        """
        assert all(
            key in kwargs
            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
        finished_requests_ids = kwargs["finished_requests_ids"]
        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
        assert "seqlen_agnostic_capture_inputs" in input_buffers
        _, input_state_indices_buffer = input_buffers[
            "seqlen_agnostic_capture_inputs"]

        self._release_finished_requests(finished_requests_ids)
        state_indices = self._prepare_current_run_cache(
            request_ids_to_seq_ids, finished_requests_ids)
        cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
            state_indices)
        state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)

        input_state_indices_buffer.copy_(
            torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer in adjusted size.
        The buffer is used to maintain the Cache during the CUDA graph replay
        runs.

        Returns:
            ``(self.cache, state_indices_tensor)`` where the tensor holds
            ``batch_size`` copies of ``PAD_SLOT_ID`` (int32, on ``"cuda"``).
        """
        state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
                                               dtype=torch.int32,
                                               device="cuda")
        return (self.cache, state_indices_tensor)

    def _take_free_cache_index(self, cur_rid: str) -> int:
        """Pop one free slot, failing with a descriptive error when none
        is left.

        Raises:
            IndexError: if every slot is already assigned. The exception
                type matches what a bare ``list.pop()`` on the empty free
                list would raise; only the message is more informative.
        """
        if not self.free_cache_indices:
            raise IndexError(
                f"No free cache index available for request {cur_rid!r}: "
                "all cache slots are currently assigned.")
        return self.free_cache_indices.pop()

    def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
                                      finished_requests_ids) -> int:
        """
        Assign (req_id,seq_id) pair to a `destination_index` index, if
        already occupied, move the occupying index to a free index.

        Returns:
            ``PAD_SLOT_ID`` for a request listed in
            ``finished_requests_ids``; otherwise the slot index now
            associated with ``(cur_rid, seq_id)``.

        Raises:
            IndexError: if a new slot is needed but none is free.
        """
        if cur_rid in finished_requests_ids:
            # set as pad, do not allocate destination index
            return PAD_SLOT_ID

        seq_ids2indices = self.cache_indices_mapping.get(cur_rid)
        if seq_ids2indices is None:
            # First time this request is seen: give it a fresh slot.
            destination_index = self._take_free_cache_index(cur_rid)
            self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
            return destination_index

        if seq_id not in seq_ids2indices:
            # parallel sampling , where n > 1, assume prefill have
            # already happened, so we copy the
            # existing cache into the siblings seq_ids caches
            index_exists = next(iter(seq_ids2indices.values()))
            # case of decoding n>1, copy prefill cache to decoding indices
            destination_index = self._take_free_cache_index(cur_rid)
            self._copy_cache(from_index=index_exists,
                             to_index=destination_index)
            seq_ids2indices[seq_id] = destination_index
            return destination_index

        # Known (request, sequence) pair: reuse its slot.
        return seq_ids2indices[seq_id]

    def _prepare_current_run_cache(
            self, request_ids_to_seq_ids: dict[str, list[int]],
            finished_requests_ids: list[str]) -> list[int]:
        """Resolve one slot index per sequence, in the iteration order of
        ``request_ids_to_seq_ids`` and of each request's ``seq_ids`` list."""
        return [
            self._assign_seq_id_to_cache_index(req_id, seq_id,
                                               finished_requests_ids)
            for req_id, seq_ids in request_ids_to_seq_ids.items()
            for seq_id in seq_ids
        ]

    def _release_finished_requests(self,
                                   finished_seq_groups_req_ids: list[str]):
        """Return every slot held by the given requests to the free list
        and forget those requests. Unknown request ids are ignored."""
        for req_id in finished_seq_groups_req_ids:
            # Single lookup: remove the request and get its slots at once.
            released = self.cache_indices_mapping.pop(req_id, None)
            if released is not None:
                self.free_cache_indices.extend(released.values())
vllm/model_executor/models/falcon_h1.py
View file @
a903669e
...
...
@@ -8,21 +8,17 @@ import torch
from
torch
import
nn
from
transformers
import
FalconH1Config
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
...
...
@@ -31,8 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
...
...
@@ -179,16 +173,12 @@ class FalconH1SSMDecoderLayer(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
):
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mamba
(
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
mup_vector
=
self
.
mup_vector
,
)
return
output
,
residual
...
...
@@ -364,8 +354,6 @@ class FalconH1ParallelHybrid(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
):
residual
=
hidden_states
...
...
@@ -382,12 +370,10 @@ class FalconH1ParallelHybrid(nn.Module):
# Process input through the SSM branch.
# FalconH1SSMDecoderLayer expects hidden_states, attn_metadata,
# residual,
mamba_cache_params,
and sequence_idx.
# residual, and sequence_idx.
ssm_hidden
,
_
=
self
.
mamba
(
hidden_states
=
hidden_states
*
self
.
ssm_in_multiplier
,
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
**
kwargs
,
)
# Sum the outputs from both branches.
...
...
@@ -464,25 +450,10 @@ class FalconH1Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
attn_metadata
=
get_forward_context
().
attn_metadata
if
not
envs
.
VLLM_USE_V1
:
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
mamba_chunk_size
,
attn_metadata
=
attn_metadata
,
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
*
self
.
embedding_multiplier
...
...
@@ -495,14 +466,9 @@ class FalconH1Model(nn.Module):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
layer_mamba_cache_params
=
None
if
mamba_cache_params
:
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
)
hidden_states
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
mamba_cache_params
=
layer_mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
...
...
@@ -541,13 +507,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
get_mamba_state_shape_from_config
(
cls
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
...
...
@@ -570,7 +534,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim
=
hf_config
.
mamba_d_head
,
state_size
=
hf_config
.
mamba_d_state
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
use_v1
=
use_v1
,
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -592,7 +555,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
tie_word_embeddings
=
config
.
tie_word_embeddings
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
if
get_pp_group
().
is_last_rank
:
...
...
@@ -637,40 +599,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
**
kwargs
,
):
mamba_cache_params
=
None
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
mamba_state_shape
=
\
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
,
use_v1
=
False
)
mamba_state_dtype
=
\
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
self
.
config
.
num_hidden_layers
,
*
mamba_state_shape
,
*
mamba_state_dtype
,
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
,
)
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
return
self
.
mamba_cache
.
copy_inputs_before_cuda_graphs
(
input_buffers
,
**
kwargs
)
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
return
self
.
mamba_cache
.
get_seqlen_agnostic_capture_inputs
(
batch_size
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/granitemoehybrid.py
View file @
a903669e
...
...
@@ -9,19 +9,15 @@ import torch
from
torch
import
nn
from
transformers
import
GraniteMoeHybridConfig
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
...
...
@@ -30,10 +26,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.granitemoe
import
GraniteMoeMoE
from
.granitemoeshared
import
GraniteMoeSharedMLP
...
...
@@ -102,14 +95,12 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
):
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mamba
(
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
)
self
.
mamba
(
hidden_states
,
output
)
hidden_states
=
residual
+
output
*
self
.
residual_multiplier
residual
=
hidden_states
...
...
@@ -182,8 +173,6 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
...
...
@@ -366,22 +355,10 @@ class GraniteMoeHybridModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
attn_metadata
=
get_forward_context
().
attn_metadata
if
not
envs
.
VLLM_USE_V1
:
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
mamba_chunk_size
,
attn_metadata
=
attn_metadata
,
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
...
...
@@ -399,20 +376,9 @@ class GraniteMoeHybridModel(nn.Module):
for
i
,
layer
in
enumerate
(
self
.
layers
):
if
isinstance
(
layer
,
GraniteMoeHybridAttentionDecoderLayer
):
num_attn
+=
1
layer_mamba_cache_params
=
None
if
isinstance
(
layer
,
GraniteMoeHybridMambaDecoderLayer
)
and
mamba_cache_params
:
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
num_attn
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
...
...
@@ -552,13 +518,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
def
get_mamba_state_shape_from_config
(
cls
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
...
...
@@ -577,7 +541,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
head_dim
=
hf_config
.
mamba_d_head
,
state_size
=
hf_config
.
mamba_d_state
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
use_v1
=
use_v1
,
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -620,9 +583,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
scale
=
1
/
self
.
config
.
logits_scaling
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
...
...
@@ -636,38 +596,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
mamba_cache_params
=
None
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
num_mamba_layers
=
(
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
))
mamba_state_shape
=
\
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
,
use_v1
=
False
)
mamba_state_dtype
=
\
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_mamba_layers
,
*
mamba_state_shape
,
*
mamba_state_dtype
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
return
self
.
mamba_cache
.
copy_inputs_before_cuda_graphs
(
input_buffers
,
**
kwargs
)
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
return
self
.
mamba_cache
.
get_seqlen_agnostic_capture_inputs
(
batch_size
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/jamba.py
View file @
a903669e
...
...
@@ -9,7 +9,6 @@ import torch
from
torch
import
nn
from
transformers
import
JambaConfig
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
...
...
@@ -30,10 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama
import
LlamaMLP
as
JambaMLP
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
is_pp_missing_parameter
,
...
...
@@ -145,7 +141,6 @@ class JambaMambaDecoderLayer(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
**
kwargs
,
):
if
residual
is
None
:
...
...
@@ -156,7 +151,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states
,
residual
)
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mamba
(
hidden_states
,
output
,
mamba_cache_params
)
self
.
mamba
(
hidden_states
,
output
)
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
output
,
residual
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
...
...
@@ -333,7 +328,6 @@ class JambaModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -348,24 +342,11 @@ class JambaModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
kv_cache_index
=
0
mamba_cache_index
=
0
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
):
layer_mamba_cache_params
=
None
if
isinstance
(
layer
,
JambaAttentionDecoderLayer
):
kv_cache_index
+=
1
if
isinstance
(
layer
,
JambaMambaDecoderLayer
)
and
mamba_cache_params
:
current_state_layer
=
mamba_cache_index
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
current_state_layer
)
mamba_cache_index
+=
1
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
...
...
@@ -503,8 +484,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
...
...
@@ -521,24 +500,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params
=
None
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
num_layers
=
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
state_shape
=
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
)
state_dtype
=
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_layers
,
*
state_shape
,
*
state_dtype
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
...
...
@@ -574,7 +538,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_size
=
hf_config
.
mamba_expand
*
hidden_size
,
state_size
=
hf_config
.
mamba_d_state
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
use_v1
=
envs
.
VLLM_USE_V1
,
)
def
compute_logits
(
...
...
vllm/model_executor/models/lfm2.py
View file @
a903669e
...
...
@@ -8,7 +8,6 @@ import torch
import
torch.nn
as
nn
from
transformers
import
Lfm2Config
from
vllm
import
envs
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
...
...
@@ -297,7 +296,6 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
self
.
conv
(
hidden_states
,
output
,
conv_metadata
=
None
,
)
hidden_states
,
residual
=
self
.
ffn_norm
(
output
,
residual
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
...
...
@@ -459,13 +457,11 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
get_mamba_state_shape_from_config
(
cls
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
]]:
""" Calculate shapes for LFM2's convolutional cache.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
...
...
@@ -478,7 +474,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
tp_world_size
=
parallel_config
.
tensor_parallel_size
,
intermediate_size
=
hf_config
.
conv_dim
,
conv_kernel
=
hf_config
.
conv_L_cache
,
use_v1
=
use_v1
,
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
...
...
@@ -489,8 +484,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
scheduler_config
=
vllm_config
.
scheduler_config
assert
(
not
cache_config
.
enable_prefix_caching
),
"Lfm2 currently does not support prefix caching"
assert
envs
.
VLLM_USE_V1
,
(
"Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1"
)
super
().
__init__
()
self
.
config
=
config
...
...
vllm/model_executor/models/mamba.py
View file @
a903669e
...
...
@@ -8,7 +8,6 @@ import torch
from
torch
import
nn
from
transformers
import
MambaConfig
from
vllm
import
envs
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed.parallel_state
import
get_pp_group
...
...
@@ -24,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
,
SupportsPP
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
...
...
@@ -72,7 +68,6 @@ class MambaDecoderLayer(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
**
kwargs
,
):
if
residual
is
None
:
...
...
@@ -82,7 +77,7 @@ class MambaDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mixer
(
hidden_states
,
output
,
mamba_cache_params
)
self
.
mixer
(
hidden_states
,
output
)
return
output
,
residual
...
...
@@ -134,7 +129,6 @@ class MambaModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
Optional
[
MambaCacheParams
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -151,17 +145,9 @@ class MambaModel(nn.Module):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
layer_cache_params
=
None
if
mamba_cache_params
is
not
None
:
layer_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
self
.
start_layer
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
layer_cache_params
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
...
...
@@ -225,9 +211,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
...
...
@@ -244,22 +227,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
mamba_cache_params
=
None
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
num_layers
=
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
state_shape
=
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
)
state_dtype
=
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_layers
,
*
state_shape
,
*
state_dtype
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
mamba_cache_params
,
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
@@ -288,8 +256,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
tp_world_size
=
parallel_config
.
tensor_parallel_size
,
intermediate_size
=
hf_config
.
intermediate_size
,
state_size
=
hf_config
.
state_size
,
conv_kernel
=
hf_config
.
conv_kernel
,
use_v1
=
envs
.
VLLM_USE_V1
)
conv_kernel
=
hf_config
.
conv_kernel
)
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
return
self
.
mamba_cache
.
copy_inputs_before_cuda_graphs
(
...
...
vllm/model_executor/models/mamba2.py
View file @
a903669e
...
...
@@ -8,16 +8,11 @@ import torch
from
torch
import
nn
from
transformers
import
MambaConfig
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
...
...
@@ -28,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
...
...
@@ -74,8 +66,6 @@ class Mamba2DecoderLayer(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
):
if
residual
is
None
:
...
...
@@ -85,7 +75,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mixer
(
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
)
self
.
mixer
(
hidden_states
,
output
)
return
output
,
residual
...
...
@@ -137,7 +127,6 @@ class Mamba2Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -152,25 +141,10 @@ class Mamba2Model(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
if
not
envs
.
VLLM_USE_V1
:
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
chunk_size
,
attn_metadata
=
attn_metadata
,
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
for
i
,
layer
in
enumerate
(
self
.
layers
):
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
self
.
start_layer
)
if
mamba_cache_params
else
None
,
mamba2_metadata
=
mamba2_metadata
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
...
...
@@ -222,13 +196,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def
get_mamba_state_shape_from_config
(
cls
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
...
...
@@ -247,7 +219,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
head_dim
=
hf_config
.
head_dim
,
state_size
=
hf_config
.
state_size
,
conv_kernel
=
hf_config
.
conv_kernel
,
use_v1
=
use_v1
,
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -282,9 +253,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
lm_head
.
tie_weights
(
self
.
backbone
.
embeddings
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
...
...
@@ -300,29 +268,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
num_mamba_layers
=
(
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
))
mamba_state_shape
=
\
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
,
use_v1
=
False
)
mamba_state_dtype
=
\
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_mamba_layers
,
*
mamba_state_shape
,
*
mamba_state_dtype
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
else
:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params
=
None
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
mamba_cache_params
,
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/mamba_cache.py
deleted
100644 → 0
View file @
2c58742d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
torch
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.config
import
VllmConfig
from
vllm.model_executor.models.constant_size_cache
import
ConstantSizeCache
@dataclass
class MambaCacheParams:
    """Bundle of the Mamba state tensors used by a forward pass.

    The same container is used for the full cache and for the per-layer
    view returned by `at_layer_idx`.
    """

    # Convolution state; indexed by layer along its first dimension.
    conv_state: torch.Tensor = torch.Tensor()
    # SSM (temporal) state; indexed by layer along its first dimension.
    ssm_state: torch.Tensor = torch.Tensor()
    # Indices tensor passed through unchanged to every per-layer view.
    state_indices_tensor: torch.Tensor = torch.Tensor()

    def at_layer_idx(self, layer_idx):
        """Return the parameters restricted to layer `layer_idx`.

        Args:
            layer_idx: Index applied to the first dimension of
                `conv_state` and `ssm_state`.

        Returns:
            A new `MambaCacheParams` holding the selected layer's conv and
            ssm state, sharing this object's `state_indices_tensor`.
        """
        layer_conv_state = self.conv_state[layer_idx]
        layer_ssm_state = self.ssm_state[layer_idx]
        return MambaCacheParams(
            conv_state=layer_conv_state,
            ssm_state=layer_ssm_state,
            state_indices_tensor=self.state_indices_tensor,
        )
class MambaCacheManager(ConstantSizeCache):
    """Allocates and hands out the conv / temporal state buffers for Mamba.

    Both buffers are laid out as (num_mamba_layers, max_batch_size, ...) and
    are stored together as the tuple exposed by `cache`.
    """

    def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int,
                 conv_state_shape: tuple[int, int],
                 temporal_state_shape: tuple[int, int],
                 conv_state_dtype: torch.dtype,
                 temporal_state_dtype: torch.dtype):
        """Size and allocate the cache buffers on the CUDA device.

        Args:
            vllm_config: Source of `max_num_seqs`, `enforce_eager` and the
                CUDA-graph padding helper.
            num_mamba_layers: Size of the leading (layer) dimension.
            conv_state_shape: Per-sequence conv state shape, assumed to be
                (dim, state_len).
            temporal_state_shape: Per-sequence temporal state shape; must
                be a tuple since it is concatenated onto the leading dims.
            conv_state_dtype: dtype of the conv state buffer.
            temporal_state_dtype: dtype of the temporal state buffer.
        """
        self.conv_state_dtype = conv_state_dtype
        self.temporal_state_dtype = temporal_state_dtype

        # The scheduler's sequence limit sets the cache size; unless eager
        # mode is enforced it is padded via pad_for_cudagraph.
        scheduler_limit = vllm_config.scheduler_config.max_num_seqs
        if vllm_config.model_config.enforce_eager:
            max_batch_size = scheduler_limit
        else:
            max_batch_size = vllm_config.pad_for_cudagraph(scheduler_limit)

        # Initialize parent class
        super().__init__(max_batch_size)

        # assume conv_state = (dim, state_len)
        conv_dim = conv_state_shape[0]
        conv_len = conv_state_shape[1]
        assert conv_dim > conv_len

        leading_dims = (num_mamba_layers, max_batch_size)

        # Allocate as (..., state_len, dim) and expose the transposed view,
        # so the logical shape is (..., dim, state_len).
        conv_storage = torch.empty(size=leading_dims + (conv_len, conv_dim),
                                   dtype=self.conv_state_dtype,
                                   device="cuda")
        conv_state = conv_storage.transpose(-1, -2)

        temporal_state = torch.empty(size=leading_dims + temporal_state_shape,
                                     dtype=self.temporal_state_dtype,
                                     device="cuda")

        self._mamba_cache = (conv_state, temporal_state)

    @property
    def cache(self):
        """The `(conv_state, temporal_state)` tuple of cache buffers."""
        return self._mamba_cache

    def _copy_cache(self, from_index: int, to_index: int):
        """Copy slot `from_index` into slot `to_index` in every buffer.

        Slots are addressed along dim 1 (the max_batch_size dimension), so
        each copy covers all layers at once.
        """
        for state_buffer in self.cache:
            source_slot = state_buffer[:, from_index]
            state_buffer[:, to_index].copy_(source_slot, non_blocking=True)

    def current_run_tensors(self, **kwargs) -> MambaCacheParams:
        """
        Return the tensors for the current run's conv and ssm state.
        """
        run_state = super().current_run_tensors(**kwargs)
        cache_tensors, state_indices_tensor = run_state
        return MambaCacheParams(
            conv_state=cache_tensors[0],
            ssm_state=cache_tensors[1],
            state_indices_tensor=state_indices_tensor,
        )

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer in adjusted size.
        The buffer is used to maintain the Mamba Cache during the CUDA graph
        replay runs.
        """
        # One PAD_SLOT_ID entry per sequence in the captured batch.
        padded_slots = [PAD_SLOT_ID] * batch_size
        state_indices = torch.as_tensor(padded_slots,
                                        dtype=torch.int32,
                                        device="cuda")
        return self._mamba_cache, state_indices
vllm/model_executor/models/minimax_cache.py
deleted
100644 → 0
View file @
2c58742d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
torch
from
vllm.model_executor.models.constant_size_cache
import
ConstantSizeCache
@dataclass
class MinimaxCacheParams:
    """Bundle of the Minimax cache tensor and its state indices.

    Used both for the whole cache and for the per-layer view produced by
    `at_layer_idx`.
    """

    # Cache tensor; indexed by layer along its first dimension.
    minimax_cache: torch.Tensor = torch.Tensor()
    # Indices tensor passed through unchanged to every per-layer view.
    state_indices_tensor: torch.Tensor = torch.Tensor()

    def at_layer_idx(self, layer_idx):
        """Return the parameters restricted to layer `layer_idx`.

        Args:
            layer_idx: Index applied to the first dimension of
                `minimax_cache`.

        Returns:
            A new `MinimaxCacheParams` with the selected layer's cache slice
            and this object's `state_indices_tensor`.
        """
        layer_cache = self.minimax_cache[layer_idx, ...]
        return MinimaxCacheParams(
            minimax_cache=layer_cache,
            state_indices_tensor=self.state_indices_tensor,
        )
class MinimaxCacheManager(ConstantSizeCache):
    """Owns the single cache tensor used by Minimax models.

    The cache is one tensor of shape `cache_shape`, whose dim 1 is the
    max batch size handed to the parent class.
    """

    def __init__(self, dtype, cache_shape):
        """Allocate the cache tensor on the CUDA device.

        Args:
            dtype: dtype of the cache tensor.
            cache_shape: Full shape of the cache tensor; `cache_shape[1]`
                is used as the maximum batch size.
        """
        super().__init__(cache_shape[1])  # max_batch_size is cache_shape[1]
        self._minimax_cache = torch.empty(size=cache_shape,
                                          dtype=dtype,
                                          device="cuda")

    @property
    def cache(self):
        """The cache tensor (a single tensor, not a tuple of tensors)."""
        return self._minimax_cache

    def _copy_cache(self, from_index: int, to_index: int):
        """Copy slot `from_index` into slot `to_index` along dim 1.

        `self.cache` is a single tensor, so iterating over it would yield
        slices along dim 0 and `slice[:, idx]` would then address the
        dimension *after* the batch dimension. Index dim 1 of the full
        tensor instead, so the copy targets the batch dimension
        (`cache_shape[1]`).

        NOTE(review): presumes `from_index` / `to_index` are batch-slot
        indices as managed by ConstantSizeCache — confirm against the
        parent class.
        """
        cache = self.cache
        assert len(cache) > 0
        cache[:, to_index].copy_(cache[:, from_index], non_blocking=True)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment