Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a903669e
Unverified
Commit
a903669e
authored
Sep 23, 2025
by
Thomas Parnell
Committed by
GitHub
Sep 23, 2025
Browse files
[V1] Remove V0 code paths for Hybrid models (#25400)
Signed-off-by:
Thomas Parnell
<
tpa@zurich.ibm.com
>
parent
2c58742d
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
206 additions
and
1156 deletions
+206
-1156
tests/models/language/generation/test_hybrid.py
tests/models/language/generation/test_hybrid.py
+19
-36
tests/models/registry.py
tests/models/registry.py
+6
-7
vllm/model_executor/layers/mamba/abstract.py
vllm/model_executor/layers/mamba/abstract.py
+1
-4
vllm/model_executor/layers/mamba/linear_attn.py
vllm/model_executor/layers/mamba/linear_attn.py
+40
-67
vllm/model_executor/layers/mamba/mamba2_metadata.py
vllm/model_executor/layers/mamba/mamba2_metadata.py
+0
-177
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+50
-99
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+56
-116
vllm/model_executor/layers/mamba/mamba_utils.py
vllm/model_executor/layers/mamba/mamba_utils.py
+2
-18
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+0
-3
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+3
-18
vllm/model_executor/models/bamba.py
vllm/model_executor/models/bamba.py
+3
-69
vllm/model_executor/models/constant_size_cache.py
vllm/model_executor/models/constant_size_cache.py
+0
-137
vllm/model_executor/models/falcon_h1.py
vllm/model_executor/models/falcon_h1.py
+1
-64
vllm/model_executor/models/granitemoehybrid.py
vllm/model_executor/models/granitemoehybrid.py
+6
-73
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+8
-45
vllm/model_executor/models/lfm2.py
vllm/model_executor/models/lfm2.py
+0
-7
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+6
-39
vllm/model_executor/models/mamba2.py
vllm/model_executor/models/mamba2.py
+5
-58
vllm/model_executor/models/mamba_cache.py
vllm/model_executor/models/mamba_cache.py
+0
-83
vllm/model_executor/models/minimax_cache.py
vllm/model_executor/models/minimax_cache.py
+0
-36
No files found.
tests/models/language/generation/test_hybrid.py
View file @
a903669e
...
@@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model
...
@@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model
SSM_MODELS
=
[
SSM_MODELS
=
[
"state-spaces/mamba-130m-hf"
,
"state-spaces/mamba-130m-hf"
,
"tiiuae/falcon-mamba-tiny-dev"
,
"tiiuae/falcon-mamba-tiny-dev"
,
"yujiepan/mamba2-codestral-v0.1-tiny-random"
,
# mamba2-codestral in transformers is broken pending:
# https://github.com/huggingface/transformers/pull/40861
#"yujiepan/mamba2-codestral-v0.1-tiny-random",
]
]
HYBRID_MODELS
=
[
HYBRID_MODELS
=
[
...
@@ -31,18 +33,7 @@ HYBRID_MODELS = [
...
@@ -31,18 +33,7 @@ HYBRID_MODELS = [
"ibm-granite/granite-4.0-tiny-preview"
,
"ibm-granite/granite-4.0-tiny-preview"
,
"tiiuae/Falcon-H1-0.5B-Base"
,
"tiiuae/Falcon-H1-0.5B-Base"
,
"LiquidAI/LFM2-1.2B"
,
"LiquidAI/LFM2-1.2B"
,
]
"tiny-random/qwen3-next-moe"
,
V1_SUPPORTED_MODELS
=
[
"state-spaces/mamba-130m-hf"
,
"ai21labs/Jamba-tiny-dev"
,
"pfnet/plamo-2-1b"
,
"yujiepan/mamba2-codestral-v0.1-tiny-random"
,
"Zyphra/Zamba2-1.2B-instruct"
,
"hmellor/tiny-random-BambaForCausalLM"
,
"ibm-granite/granite-4.0-tiny-preview"
,
"tiiuae/Falcon-H1-0.5B-Base"
,
"LiquidAI/LFM2-1.2B"
,
]
]
FULL_CUDA_GRAPH_MODELS
=
[
FULL_CUDA_GRAPH_MODELS
=
[
...
@@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [
...
@@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [
"Zyphra/Zamba2-1.2B-instruct"
,
"Zyphra/Zamba2-1.2B-instruct"
,
]
]
V0_UNSUPPORTED_MODELS
=
[
"LiquidAI/LFM2-1.2B"
,
]
FP32_STATE_MODELS
=
[
FP32_STATE_MODELS
=
[
"state-spaces/mamba-130m-hf"
,
"state-spaces/mamba-130m-hf"
,
"Zyphra/Zamba2-1.2B-instruct"
,
"Zyphra/Zamba2-1.2B-instruct"
,
...
@@ -88,20 +75,16 @@ def test_models(
...
@@ -88,20 +75,16 @@ def test_models(
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
if
model
in
V1_SUPPORTED_MODELS
:
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
vllm_v1_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
else
:
vllm_v1_outputs
=
None
if
model
in
V1_SUPPORTED_MODELS
:
check_logprobs_close
(
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
outputs_1_lst
=
vllm_v1_outputs
,
name_0
=
"hf"
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
name_1
=
"vllm-v1"
,
)
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
SSM_MODELS
[
0
],
HYBRID_MODELS
[
0
]])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
SSM_MODELS
[
0
],
HYBRID_MODELS
[
0
]])
...
@@ -299,14 +282,14 @@ def test_full_cuda_graph(
...
@@ -299,14 +282,14 @@ def test_full_cuda_graph(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
vllm_
v1_
outputs
=
vllm_model
.
generate_greedy_logprobs
(
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
check_logprobs_close
(
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_
v1_
outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_0
=
"hf"
,
name_1
=
"vllm
-v1
"
,
name_1
=
"vllm"
,
)
)
...
@@ -340,12 +323,12 @@ def test_fp32_cache_state(
...
@@ -340,12 +323,12 @@ def test_fp32_cache_state(
with
vllm_runner
(
model
,
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
,
max_num_seqs
=
MAX_NUM_SEQS
,
**
{
cache_dtype_param
:
"float32"
})
as
vllm_model
:
**
{
cache_dtype_param
:
"float32"
})
as
vllm_model
:
vllm_
v1_
outputs
=
vllm_model
.
generate_greedy_logprobs
(
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
check_logprobs_close
(
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_
v1_
outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_0
=
"hf"
,
name_1
=
"vllm
-v1
"
,
name_1
=
"vllm"
,
)
)
tests/models/registry.py
View file @
a903669e
...
@@ -312,14 +312,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -312,14 +312,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PersimmonForCausalLM"
:
_HfExamplesInfo
(
"adept/persimmon-8b-chat"
),
"PersimmonForCausalLM"
:
_HfExamplesInfo
(
"adept/persimmon-8b-chat"
),
"PhiForCausalLM"
:
_HfExamplesInfo
(
"microsoft/phi-2"
),
"PhiForCausalLM"
:
_HfExamplesInfo
(
"microsoft/phi-2"
),
"Phi3ForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3-mini-4k-instruct"
),
"Phi3ForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3-mini-4k-instruct"
),
"Phi4FlashForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-4-mini-flash-reasoning"
,
# noqa: E501
trust_remote_code
=
True
,
v0_only
=
True
,
max_model_len
=
10240
),
"PhiMoEForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3.5-MoE-instruct"
,
"PhiMoEForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3.5-MoE-instruct"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Plamo2ForCausalLM"
:
_HfExamplesInfo
(
"pfnet/plamo-2-1b"
,
"Plamo2ForCausalLM"
:
_HfExamplesInfo
(
"pfnet/plamo-2-1b"
,
trust_remote_code
=
True
),
max_transformers_version
=
"4.55.4"
,
transformers_version_reason
=
"HF model uses remote code that is not compatible with latest Transformers"
,
# noqa: E501
trust_remote_code
=
True
),
"QWenLMHeadModel"
:
_HfExamplesInfo
(
"Qwen/Qwen-7B-Chat"
,
"QWenLMHeadModel"
:
_HfExamplesInfo
(
"Qwen/Qwen-7B-Chat"
,
max_transformers_version
=
"4.53"
,
max_transformers_version
=
"4.53"
,
transformers_version_reason
=
"HF model uses remote code that is not compatible with latest Transformers"
,
# noqa: E501
transformers_version_reason
=
"HF model uses remote code that is not compatible with latest Transformers"
,
# noqa: E501
...
@@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen3ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-8B"
),
"Qwen3ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-8B"
),
"Qwen3MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-30B-A3B"
),
"Qwen3MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-30B-A3B"
),
"Qwen3NextForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-Next-80B-A3B-Instruct"
,
"Qwen3NextForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-Next-80B-A3B-Instruct"
,
min_transformers_version
=
"4.56.2"
),
extras
=
{
"tiny-random"
:
"tiny-random/qwen3-next-moe"
},
# noqa: E501
min_transformers_version
=
"4.56.3"
),
"RWForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-40b"
),
"RWForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-40b"
),
"SeedOssForCausalLM"
:
_HfExamplesInfo
(
"ByteDance-Seed/Seed-OSS-36B-Instruct"
,
# noqa: E501
"SeedOssForCausalLM"
:
_HfExamplesInfo
(
"ByteDance-Seed/Seed-OSS-36B-Instruct"
,
# noqa: E501
trust_remote_code
=
True
,
trust_remote_code
=
True
,
...
@@ -644,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
...
@@ -644,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code
=
True
,
trust_remote_code
=
True
,
speculative_model
=
"XiaomiMiMo/MiMo-7B-RL"
),
speculative_model
=
"XiaomiMiMo/MiMo-7B-RL"
),
"Qwen3NextMTP"
:
_HfExamplesInfo
(
"Qwen/Qwen3-Next-80B-A3B-Instruct"
,
"Qwen3NextMTP"
:
_HfExamplesInfo
(
"Qwen/Qwen3-Next-80B-A3B-Instruct"
,
min_transformers_version
=
"4.56.
2
"
),
min_transformers_version
=
"4.56.
3
"
),
}
}
_TRANSFORMERS_BACKEND_MODELS
=
{
_TRANSFORMERS_BACKEND_MODELS
=
{
...
...
vllm/model_executor/layers/mamba/abstract.py
View file @
a903669e
...
@@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase):
...
@@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase):
# Contains the KV cache (mamba state) for the layer
# Contains the KV cache (mamba state) for the layer
# in the shape specified by `self.get_state_shape`.
# in the shape specified by `self.get_state_shape`.
# The outer list is for v0 PP virtual engine. Though this code path
kv_cache
:
tuple
[
torch
.
Tensor
,
...]
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
kv_cache
:
list
[
Iterable
[
torch
.
Tensor
]]
@
abstractmethod
@
abstractmethod
def
get_state_shape
(
self
)
->
Iterable
[
tuple
[
int
,
...]]:
def
get_state_shape
(
self
)
->
Iterable
[
tuple
[
int
,
...]]:
...
...
vllm/model_executor/layers/mamba/linear_attn.py
View file @
a903669e
...
@@ -15,7 +15,6 @@ import torch.nn.functional as F
...
@@ -15,7 +15,6 @@ import torch.nn.functional as F
from
einops
import
rearrange
from
einops
import
rearrange
from
torch
import
nn
from
torch
import
nn
from
vllm
import
envs
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
...
@@ -42,8 +41,6 @@ if TYPE_CHECKING:
...
@@ -42,8 +41,6 @@ if TYPE_CHECKING:
import
torch
import
torch
import
torch.distributed
import
torch.distributed
from
vllm.model_executor.models.minimax_cache
import
MinimaxCacheParams
class
MiniMaxText01RMSNormTP
(
CustomOp
):
class
MiniMaxText01RMSNormTP
(
CustomOp
):
name
=
"MiniMaxText01RMSNormTP"
name
=
"MiniMaxText01RMSNormTP"
...
@@ -225,11 +222,10 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
...
@@ -225,11 +222,10 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
self
.
tp_heads
:(
self
.
tp_rank
+
1
)
*
self
.
tp_heads
:(
self
.
tp_rank
+
1
)
*
self
.
tp_heads
].
contiguous
()
self
.
tp_heads
].
contiguous
()
if
envs
.
VLLM_USE_V1
:
compilation_config
=
get_current_vllm_config
().
compilation_config
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
@
staticmethod
@
staticmethod
def
weight_direct_load
(
param
:
torch
.
Tensor
,
def
weight_direct_load
(
param
:
torch
.
Tensor
,
...
@@ -268,8 +264,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
...
@@ -268,8 +264,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
break
break
if
_prefill_idx
>=
len
(
state_indices_tensor
):
if
_prefill_idx
>=
len
(
state_indices_tensor
):
break
break
# prefills are packed at end of batch in V1
offset
=
attn_metadata
.
num_decode_tokens
offset
=
attn_metadata
.
num_decode_tokens
if
envs
.
VLLM_USE_V1
else
0
_start
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
]
_start
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
]
_end
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
+
1
]
_end
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
+
1
]
slot_id
=
state_indices_tensor
[
offset
+
_prefill_idx
]
slot_id
=
state_indices_tensor
[
offset
+
_prefill_idx
]
...
@@ -291,10 +286,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
...
@@ -291,10 +286,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
hidden_decode
=
self
.
_decode_infer
(
q
,
k
,
v
,
kv_cache
,
hidden_decode
=
self
.
_decode_infer
(
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
state_indices_tensor
,
attn_metadata
)
attn_metadata
)
if
envs
.
VLLM_USE_V1
:
hidden
.
insert
(
0
,
hidden_decode
)
hidden
.
insert
(
0
,
hidden_decode
)
else
:
hidden
.
append
(
hidden_decode
)
if
not
hidden
:
if
not
hidden
:
return
torch
.
empty
((
0
,
q
.
size
(
-
1
)),
device
=
q
.
device
,
dtype
=
q
.
dtype
)
return
torch
.
empty
((
0
,
q
.
size
(
-
1
)),
device
=
q
.
device
,
dtype
=
q
.
dtype
)
...
@@ -304,40 +296,28 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
...
@@ -304,40 +296,28 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def
_decode_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
def
_decode_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
):
attn_metadata
):
if
not
envs
.
VLLM_USE_V1
:
q
=
q
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
q
=
q
[
attn_metadata
.
num_prefill_tokens
:].
unsqueeze
(
2
).
contiguous
()
k
=
k
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
k
=
k
[
attn_metadata
.
num_prefill_tokens
:].
unsqueeze
(
2
).
contiguous
()
v
=
v
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
v
=
v
[
attn_metadata
.
num_prefill_tokens
:].
unsqueeze
(
2
).
contiguous
()
slot_id
=
state_indices_tensor
[:
attn_metadata
.
num_decodes
]
num_prefills
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
slot_id
=
state_indices_tensor
[
num_prefills
:]
else
:
q
=
q
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
k
=
k
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
v
=
v
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
slot_id
=
state_indices_tensor
[:
attn_metadata
.
num_decodes
]
hidden
=
linear_decode_forward_triton
(
q
,
k
,
v
,
kv_cache
,
self
.
tp_slope
,
hidden
=
linear_decode_forward_triton
(
q
,
k
,
v
,
kv_cache
,
self
.
tp_slope
,
slot_id
,
32
)
slot_id
,
32
)
return
hidden
return
hidden
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
)
->
None
:
kv_caches
:
MinimaxCacheParams
)
->
None
:
torch
.
ops
.
vllm
.
linear_attention
(
if
not
envs
.
VLLM_USE_V1
:
hidden_states
,
self
.
_forward
(
hidden_states
,
output
,
positions
,
kv_caches
)
output
,
else
:
positions
,
torch
.
ops
.
vllm
.
linear_attention
(
self
.
prefix
,
hidden_states
,
)
output
,
positions
,
self
.
prefix
,
)
def
_forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
def
_forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
)
->
None
:
kv_caches
:
Optional
[
MinimaxCacheParams
])
->
None
:
forward_context
=
get_forward_context
()
forward_context
=
get_forward_context
()
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
if
envs
.
VLLM_USE_V1
and
attn_metadata
is
not
None
:
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
LinearAttentionMetadata
)
assert
isinstance
(
attn_metadata
,
LinearAttentionMetadata
)
...
@@ -351,32 +331,26 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
...
@@ -351,32 +331,26 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
qkvact
=
torch
.
nn
.
functional
.
silu
(
qkv32
)
qkvact
=
torch
.
nn
.
functional
.
silu
(
qkv32
)
qkvact
=
qkvact
.
view
((
qkv
.
shape
[
0
],
self
.
tp_heads
,
-
1
))
qkvact
=
qkvact
.
view
((
qkv
.
shape
[
0
],
self
.
tp_heads
,
-
1
))
q
,
k
,
v
=
torch
.
split
(
qkvact
,
[
self
.
head_dim
]
*
3
,
dim
=-
1
)
q
,
k
,
v
=
torch
.
split
(
qkvact
,
[
self
.
head_dim
]
*
3
,
dim
=-
1
)
if
envs
.
VLLM_USE_V1
:
if
attn_metadata
is
not
None
:
if
attn_metadata
is
not
None
:
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
num_prefills
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
num_prefills
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
if
num_prefills
>
0
:
if
num_prefills
>
0
:
num_decode_tokens
=
getattr
(
attn_metadata
,
"num_decode_tokens"
,
num_decode_tokens
=
getattr
(
attn_metadata
,
0
)
"num_decode_tokens"
,
0
)
for
prefill_idx
in
range
(
num_prefills
):
for
prefill_idx
in
range
(
num_prefills
):
q_start
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
q_start
=
attn_metadata
.
query_start_loc
[
prefill_idx
]
num_decode_tokens
+
prefill_idx
]
q_end
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
q_end
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
prefill_idx
+
1
]
+
prefill_idx
+
query_len
=
q_end
-
q_start
1
]
context_len
=
attn_metadata
.
seq_lens
[
query_len
=
q_end
-
q_start
num_decode_tokens
+
prefill_idx
]
-
query_len
context_len
=
attn_metadata
.
seq_lens
[
if
context_len
==
0
:
num_decode_tokens
+
prefill_idx
]
-
query_len
block_to_clear
=
state_indices_tensor
[
num_decode_tokens
if
context_len
==
0
:
+
prefill_idx
]
block_to_clear
=
state_indices_tensor
[
kv_cache
[
block_to_clear
,
...]
=
0
num_decode_tokens
+
prefill_idx
]
kv_cache
[
block_to_clear
,
...]
=
0
else
:
assert
kv_caches
is
not
None
kv_cache
=
kv_caches
.
minimax_cache
state_indices_tensor
=
kv_caches
.
state_indices_tensor
decode_only
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
==
0
decode_only
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
==
0
if
attn_metadata
is
None
:
if
attn_metadata
is
None
:
...
@@ -410,8 +384,7 @@ def linear_attention(
...
@@ -410,8 +384,7 @@ def linear_attention(
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
_forward
(
hidden_states
=
hidden_states
,
self
.
_forward
(
hidden_states
=
hidden_states
,
output
=
output
,
output
=
output
,
positions
=
positions
,
positions
=
positions
)
kv_caches
=
None
)
def
linear_attention_fake
(
def
linear_attention_fake
(
...
...
vllm/model_executor/layers/mamba/mamba2_metadata.py
deleted
100644 → 0
View file @
2c58742d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Union
import
numpy
as
np
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.placeholder_attn
import
(
PlaceholderAttentionMetadata
)
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadata
from
vllm.v1.attention.backends.mamba2_attn
import
(
Mamba2AttentionMetadata
,
_query_start_loc_to_chunk_indices_offsets
)
@
dataclass
class
Mamba2Metadata
:
prep_initial_states
:
bool
chunk_size
:
int
has_initial_states_p
:
torch
.
Tensor
seq_idx_p
:
torch
.
Tensor
chunk_indices_p
:
torch
.
Tensor
chunk_offsets_p
:
torch
.
Tensor
"""
With continuous batching layout of `x` in vLLM, to enable a Triton program
to handle a request in parallel, two supporting tensors are used
(batch_ptr, token_chunk_offset_ptr)
BLOCK_M = the # tokens to be handled by a Triton program
(can be customized for different hardware)
nums_dict:
tracks the data associated with a given value of BLOCK_M
BLOCK_M = #tokens handled by a Triton program
cu_seqlen: total tokens per batch
(used as flag to update other data at each new input)
batch_ptr: tracks batch-id handled by the Triton program
token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
(Triton implementation of causal_conv1d handles parallelism in 3-axes
- feature-axis
- batch-axis
- sequence-axis)
"""
nums_dict
:
Optional
[
dict
]
=
None
cu_seqlen
:
Optional
[
int
]
=
None
batch_ptr
:
Optional
[
torch
.
Tensor
]
=
None
token_chunk_offset_ptr
:
Optional
[
torch
.
Tensor
]
=
None
def
get_platform_metadata_classes
()
->
tuple
[
type
[
AttentionMetadata
],
...]:
"""Returns the appropriate metadata classes for the current platform."""
if
current_platform
.
is_rocm
():
from
vllm.v1.attention.backends.rocm_aiter_fa
import
(
AiterFlashAttentionMetadata
)
from
vllm.v1.attention.backends.triton_attn
import
(
TritonAttentionMetadata
)
return
(
AiterFlashAttentionMetadata
,
TritonAttentionMetadata
,
PlaceholderAttentionMetadata
)
if
current_platform
.
is_cuda
():
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionMetadata
)
from
vllm.v1.attention.backends.xformers
import
(
XFormersAttentionMetadata
)
return
(
FlashAttentionMetadata
,
XFormersAttentionMetadata
,
PlaceholderAttentionMetadata
)
raise
ValueError
(
f
"Unsupported platform for Mamba2:
{
current_platform
.
device_type
}
"
)
def
prepare_mamba2_metadata
(
chunk_size
:
int
,
attn_metadata
:
AttentionMetadata
,
)
->
Mamba2Metadata
:
# compute number of prefill and decode requests
# NOTE: in V0 we assume prefills are before decodes
num_prefills
=
attn_metadata
.
num_prefills
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
seq_idx_p
=
None
chunk_indices_p
,
chunk_offsets_p
=
None
,
None
# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states_p
=
None
prep_initial_states
=
False
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if
num_prefills
>
0
:
attn_metadata_instances
=
get_platform_metadata_classes
()
if
(
isinstance
(
attn_metadata
,
attn_metadata_instances
)
and
attn_metadata
.
context_lens_tensor
is
not
None
):
# precompute flag to avoid device syncs later in mamba2 layer
# forwards
# prep is only needed for mamba2 ssd prefill processing
has_initial_states_p
=
(
attn_metadata
.
context_lens_tensor
[:
num_prefills
]
>
0
)
prep_initial_states
=
torch
.
any
(
has_initial_states_p
).
item
()
query_start_loc_p
=
attn_metadata
.
query_start_loc
[:
num_prefills
+
1
]
seq_idx_p
=
torch
.
repeat_interleave
(
torch
.
arange
(
num_prefills
,
dtype
=
torch
.
int32
,
device
=
query_start_loc_p
.
device
),
query_start_loc_p
.
diff
(),
output_size
=
num_prefill_tokens
)
seq_idx_p
.
unsqueeze_
(
0
)
# We compute metadata for chunked prefill once at the top level model
# forward and reuse them in mamba layers. If not needed, they will be
# ignored inside mamba kernels.
if
prep_initial_states
:
chunk_indices_p
,
chunk_offsets_p
=
\
_query_start_loc_to_chunk_indices_offsets
(
query_start_loc_p
,
chunk_size
,
num_prefill_tokens
)
return
Mamba2Metadata
(
has_initial_states_p
=
has_initial_states_p
,
prep_initial_states
=
prep_initial_states
,
chunk_size
=
chunk_size
,
seq_idx_p
=
seq_idx_p
,
chunk_indices_p
=
chunk_indices_p
,
chunk_offsets_p
=
chunk_offsets_p
)
def
update_metadata
(
x
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
mamba2_metadata
:
Union
[
Mamba2Metadata
,
Mamba2AttentionMetadata
,
GDNAttentionMetadata
]):
"""
this is triggered upon handling a new input at the first layer
"""
dim
,
cu_seqlen
=
x
.
shape
mamba2_metadata
.
cu_seqlen
=
cu_seqlen
seqlens
=
np
.
diff
(
query_start_loc
.
to
(
'cpu'
))
nums_dict
=
{}
# type: ignore
for
BLOCK_M
in
[
8
]:
# cover all BLOCK_M values
nums
=
-
(
-
seqlens
//
BLOCK_M
)
nums_dict
[
BLOCK_M
]
=
{}
nums_dict
[
BLOCK_M
][
'nums'
]
=
nums
nums_dict
[
BLOCK_M
][
'tot'
]
=
nums
.
sum
().
item
()
mlist
=
torch
.
from_numpy
(
np
.
repeat
(
np
.
arange
(
len
(
nums
)),
nums
))
nums_dict
[
BLOCK_M
][
'mlist'
]
=
mlist
mlist_len
=
len
(
nums_dict
[
BLOCK_M
][
'mlist'
])
nums_dict
[
BLOCK_M
][
'mlist_len'
]
=
mlist_len
MAX_NUM_PROGRAMS
=
max
(
1024
,
mlist_len
)
*
2
offsetlist
=
[]
# type: ignore
for
idx
,
num
in
enumerate
(
nums
):
offsetlist
.
extend
(
range
(
num
))
offsetlist
=
torch
.
tensor
(
offsetlist
,
dtype
=
torch
.
int32
)
nums_dict
[
BLOCK_M
][
'offsetlist'
]
=
offsetlist
if
mamba2_metadata
.
batch_ptr
is
None
:
# Update default value after class definition
#mamba2_metadata.MAX_NUM_PROGRAMS *= 2
mamba2_metadata
.
batch_ptr
=
torch
.
full
((
MAX_NUM_PROGRAMS
,
),
PAD_SLOT_ID
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
mamba2_metadata
.
token_chunk_offset_ptr
=
torch
.
full
(
(
MAX_NUM_PROGRAMS
,
),
PAD_SLOT_ID
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
else
:
if
mamba2_metadata
.
batch_ptr
.
nelement
()
<
MAX_NUM_PROGRAMS
:
mamba2_metadata
.
batch_ptr
.
resize_
(
MAX_NUM_PROGRAMS
).
fill_
(
PAD_SLOT_ID
)
mamba2_metadata
.
token_chunk_offset_ptr
.
resize_
(
# type: ignore
MAX_NUM_PROGRAMS
).
fill_
(
PAD_SLOT_ID
)
mamba2_metadata
.
batch_ptr
[
0
:
mlist_len
].
copy_
(
mlist
)
mamba2_metadata
.
token_chunk_offset_ptr
[
# type: ignore
0
:
mlist_len
].
copy_
(
offsetlist
)
nums_dict
[
BLOCK_M
][
'batch_ptr'
]
=
mamba2_metadata
.
batch_ptr
nums_dict
[
BLOCK_M
][
'token_chunk_offset_ptr'
]
=
(
mamba2_metadata
.
token_chunk_offset_ptr
)
# type: ignore
mamba2_metadata
.
nums_dict
=
nums_dict
return
mamba2_metadata
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
a903669e
...
@@ -10,8 +10,6 @@ import torch
...
@@ -10,8 +10,6 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
...
@@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
...
@@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn
,
causal_conv1d_update
)
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
selective_scan_fn
,
selective_state_update
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
@@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp):
...
@@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp):
has_weight
=
rms_norm_has_weight
,
has_weight
=
rms_norm_has_weight
,
)
if
use_rms_norm
else
None
)
if
use_rms_norm
else
None
if
envs
.
VLLM_USE_V1
:
compilation_config
=
get_current_vllm_config
().
compilation_config
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
# The inner tuple is (conv_state, ssm_state)
# The outer list is for v0 PP virtual engine. Though this code path
self
.
kv_cache
=
(
torch
.
tensor
([]),
torch
.
tensor
([]))
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self
.
kv_cache
=
[(
torch
.
tensor
([]),
torch
.
tensor
([]))]
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
...
@@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp):
...
@@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp):
discrete_time_step
=
self
.
dt_proj
(
time_step
)[
0
].
transpose
(
-
2
,
-
1
)
discrete_time_step
=
self
.
dt_proj
(
time_step
)[
0
].
transpose
(
-
2
,
-
1
)
return
discrete_time_step
,
B
,
C
return
discrete_time_step
,
B
,
C
def
forward
(
self
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
hidden_states
:
torch
.
Tensor
,
torch
.
ops
.
vllm
.
mamba_mixer
(
output
:
torch
.
Tensor
,
hidden_states
,
mamba_cache_params
:
Optional
[
MambaCacheParams
]
=
None
):
output
,
if
not
envs
.
VLLM_USE_V1
:
self
.
prefix
,
CustomOp
.
forward
(
self
,
hidden_states
,
output
,
mamba_cache_params
)
)
else
:
torch
.
ops
.
vllm
.
mamba_mixer
(
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
,
output
:
torch
.
Tensor
):
output
,
self
.
prefix
,
)
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
Optional
[
MambaCacheParams
]
=
None
):
pass
pass
def
forward_cuda
(
self
,
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
Optional
[
MambaCacheParams
]
=
None
):
"""
"""
Run the Mamba-1 SSM pipeline.
Run the Mamba-1 SSM pipeline.
...
@@ -234,31 +216,18 @@ class MambaMixer(MambaBase, CustomOp):
...
@@ -234,31 +216,18 @@ class MambaMixer(MambaBase, CustomOp):
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
if
envs
.
VLLM_USE_V1
:
if
attn_metadata
is
not
None
:
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
attn_metadata
=
attn_metadata
[
self
.
prefix
]
mamba1_metadata
=
attn_metadata
mamba1_metadata
=
attn_metadata
assert
isinstance
(
mamba1_metadata
,
Mamba1AttentionMetadata
)
assert
isinstance
(
mamba1_metadata
,
Mamba1AttentionMetadata
)
query_start_loc
=
mamba1_metadata
.
query_start_loc
query_start_loc
=
mamba1_metadata
.
query_start_loc
state_indices_tensor
=
mamba1_metadata
.
state_indices_tensor
state_indices_tensor
=
mamba1_metadata
.
state_indices_tensor
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
ssm_state
=
self_kv_cache
[
1
]
has_initial_states
=
mamba1_metadata
.
has_initial_states
has_initial_states
=
mamba1_metadata
.
has_initial_states
num_padded_decodes
=
mamba1_metadata
.
num_padded_decodes
num_padded_decodes
=
mamba1_metadata
.
num_padded_decodes
else
:
assert
isinstance
(
attn_metadata
,
AttentionMetadata
)
assert
mamba_cache_params
is
not
None
conv_state
=
mamba_cache_params
.
conv_state
ssm_state
=
mamba_cache_params
.
ssm_state
state_indices_tensor
=
mamba_cache_params
.
state_indices_tensor
query_start_loc
=
attn_metadata
.
query_start_loc
context_lens_tensor
=
attn_metadata
.
context_lens_tensor
has_initial_states
=
None
if
context_lens_tensor
is
not
None
:
has_initial_states
=
context_lens_tensor
>
0
num_padded_decodes
=
attn_metadata
.
num_decode_tokens
# 1. Gated MLP's linear projection
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
...
@@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp):
...
@@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp):
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
))
self
.
conv1d
.
weight
.
size
(
2
))
if
envs
.
VLLM_USE_V1
and
attn_metadata
is
None
:
if
attn_metadata
is
None
:
# V1 profile run
# V1 profile run
hidden_states_BC
=
hidden_states_BC
.
contiguous
()
hidden_states_BC
=
hidden_states_BC
.
contiguous
()
return
self
.
out_proj
(
hidden_states_BC
.
transpose
(
-
2
,
-
1
))[
0
]
return
self
.
out_proj
(
hidden_states_BC
.
transpose
(
-
2
,
-
1
))[
0
]
...
@@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp):
...
@@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp):
out
=
scan_outputs_d
)
out
=
scan_outputs_d
)
scan_outputs_d
=
scan_outputs_d
.
transpose
(
0
,
1
)
scan_outputs_d
=
scan_outputs_d
.
transpose
(
0
,
1
)
if
envs
.
VLLM_USE_V1
:
ssm_outputs
.
insert
(
0
,
scan_outputs_d
)
ssm_outputs
.
insert
(
0
,
scan_outputs_d
)
else
:
ssm_outputs
.
append
(
scan_outputs_d
)
scan_outputs_combined
=
ssm_outputs
[
0
]
if
len
(
scan_outputs_combined
=
ssm_outputs
[
0
]
if
len
(
ssm_outputs
)
==
1
else
torch
.
cat
(
ssm_outputs
,
dim
=-
1
)
ssm_outputs
)
==
1
else
torch
.
cat
(
ssm_outputs
,
dim
=-
1
)
...
@@ -441,40 +407,27 @@ def split_batch_to_prefill_and_decode(
...
@@ -441,40 +407,27 @@ def split_batch_to_prefill_and_decode(
num_decodes
:
int
,
num_decodes
:
int
,
num_padded_decodes
:
int
,
num_padded_decodes
:
int
,
)
->
PrefillDecodeSplit
:
)
->
PrefillDecodeSplit
:
num_actual_tokens
=
num_prefill_tokens
+
num_padded_decodes
num_actual_tokens
=
num_prefill_tokens
+
num_padded_decodes
if
envs
.
VLLM_USE_V1
:
# In v1, decode tokens come first, then prefill tokens.
# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d
,
hidden_states_BC_p
=
torch
.
split
(
hidden_states_BC_d
,
hidden_states_BC_p
=
torch
.
split
(
hidden_states_BC
[...,
:
num_actual_tokens
],
hidden_states_BC
[...,
:
num_actual_tokens
],
[
num_padded_decodes
,
num_prefill_tokens
],
[
num_padded_decodes
,
num_prefill_tokens
],
dim
=-
1
)
dim
=-
1
)
gate_d
,
gate_p
=
torch
.
split
(
gate
[...,
:
num_actual_tokens
],
gate_d
,
gate_p
=
torch
.
split
(
gate
[...,
:
num_actual_tokens
],
[
num_padded_decodes
,
num_prefill_tokens
],
[
num_padded_decodes
,
num_prefill_tokens
],
dim
=-
1
)
dim
=-
1
)
# num_padded_decodes accounts for CUDA graph padding when applicable
# num_padded_decodes accounts for CUDA graph padding when applicable
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor
[:
num_padded_decodes
+
num_prefills
],
state_indices_tensor
[:
num_padded_decodes
+
num_prefills
],
[
num_padded_decodes
,
num_prefills
],
[
num_padded_decodes
,
num_prefills
],
dim
=
0
)
dim
=
0
)
query_start_loc_p
=
(
query_start_loc
[
-
num_prefills
-
1
:]
-
query_start_loc_p
=
(
query_start_loc
[
-
num_prefills
-
1
:]
-
num_padded_decodes
if
num_prefills
>
0
else
None
)
num_padded_decodes
if
num_prefills
>
0
else
None
)
has_initial_states_p
=
has_initial_states
[
-
num_prefills
:]
if
(
has_initial_states_p
=
has_initial_states
[
-
num_prefills
:]
if
(
has_initial_states
is
not
None
and
num_prefills
>
0
)
else
None
has_initial_states
is
not
None
and
num_prefills
>
0
)
else
None
else
:
# In v0, prefill tokens come first, then decode tokens.
hidden_states_BC_p
,
hidden_states_BC_d
=
torch
.
split
(
hidden_states_BC
,
[
num_prefill_tokens
,
num_decode_tokens
],
dim
=-
1
)
gate_p
,
gate_d
=
torch
.
split
(
gate
,
[
num_prefill_tokens
,
num_decode_tokens
],
dim
=-
1
)
state_indices_tensor_p
,
state_indices_tensor_d
=
torch
.
split
(
state_indices_tensor
,
[
num_prefills
,
num_decodes
],
dim
=
0
)
query_start_loc_p
=
(
query_start_loc
[:
num_prefills
+
1
]
if
num_prefills
>
0
else
None
)
has_initial_states_p
=
has_initial_states
[:
num_prefills
]
if
(
has_initial_states
is
not
None
and
num_prefills
>
0
)
else
None
return
PrefillDecodeSplit
(
return
PrefillDecodeSplit
(
hidden_states_BC_p
=
hidden_states_BC_p
,
hidden_states_BC_p
=
hidden_states_BC_p
,
...
@@ -495,9 +448,7 @@ def mamba_mixer(
...
@@ -495,9 +448,7 @@ def mamba_mixer(
)
->
None
:
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
output
=
output
)
output
=
output
,
mamba_cache_params
=
None
)
def
mamba_mixer_fake
(
def
mamba_mixer_fake
(
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
a903669e
...
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
...
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
...
@@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
update_metadata
)
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
...
@@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
...
@@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
LoaderFunction
,
composed_weight_loader
,
sharded_weight_loader
)
LoaderFunction
,
composed_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
@@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp):
self
.
use_rms_norm
,
self
.
use_rms_norm
,
eps
=
rms_norm_eps
)
eps
=
rms_norm_eps
)
if
envs
.
VLLM_USE_V1
:
compilation_config
=
get_current_vllm_config
().
compilation_config
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
# The tuple is (conv_state, ssm_state)
# The outer list is for v0 PP virtual engine. Though this code path
self
.
kv_cache
=
(
torch
.
tensor
([]),
torch
.
tensor
([]))
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self
.
kv_cache
=
[(
torch
.
tensor
([]),
torch
.
tensor
([]))]
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
...
@@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
pass
pass
...
@@ -478,59 +468,43 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -478,59 +468,43 @@ class MambaMixer2(MambaBase, CustomOp):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
if
not
envs
.
VLLM_USE_V1
:
torch
.
ops
.
vllm
.
mamba_mixer2
(
CustomOp
.
forward
(
self
,
hidden_states
,
output
,
mamba_cache_params
,
hidden_states
,
mamba2_metadata
,
mup_vector
)
output
,
else
:
self
.
prefix
,
torch
.
ops
.
vllm
.
mamba_mixer2
(
mup_vector
,
hidden_states
,
)
output
,
self
.
prefix
,
mup_vector
,
)
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
forward_context
=
get_forward_context
()
forward_context
=
get_forward_context
()
#
mamba2
_metadata contains metadata necessary for the mamba2 triton
#
attn
_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
# stay the same and reused for all mamba layers in the same iteration
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
if
envs
.
VLLM_USE_V1
:
if
attn_metadata
is
not
None
:
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
attn_metadata
=
attn_metadata
[
self
.
prefix
]
mamba2_metadata
=
attn_metadata
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
ssm_state
=
self_kv_cache
[
1
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
else
:
prep_initial_states
=
attn_metadata
.
prep_initial_states
conv_state
=
mamba_cache_params
.
conv_state
chunk_size
=
attn_metadata
.
chunk_size
ssm_state
=
mamba_cache_params
.
ssm_state
seq_idx_p
=
attn_metadata
.
seq_idx_p
state_indices_tensor
=
mamba_cache_params
.
state_indices_tensor
chunk_indices_p
=
attn_metadata
.
chunk_indices_p
chunk_offsets_p
=
attn_metadata
.
chunk_offsets_p
# Common members between V1 metadata and V0 metadata
if
mamba2_metadata
is
not
None
:
has_initial_states_p
=
mamba2_metadata
.
has_initial_states_p
prep_initial_states
=
mamba2_metadata
.
prep_initial_states
chunk_size
=
mamba2_metadata
.
chunk_size
seq_idx_p
=
mamba2_metadata
.
seq_idx_p
chunk_indices_p
=
mamba2_metadata
.
chunk_indices_p
chunk_offsets_p
=
mamba2_metadata
.
chunk_offsets_p
# 1. Gated MLP's linear projection
# 1. Gated MLP's linear projection
projected_states
,
_
=
self
.
in_proj
(
hidden_states
)
projected_states
,
_
=
self
.
in_proj
(
hidden_states
)
...
@@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp):
dim
=-
1
,
dim
=-
1
,
)
)
if
envs
.
VLLM_USE_V1
and
attn_metadata
is
None
:
if
attn_metadata
is
None
:
#
V1
profile run
# profile run
hidden_states_B_C
=
(
hidden_states_B_C
.
transpose
(
hidden_states_B_C
=
(
hidden_states_B_C
.
transpose
(
0
,
1
).
clone
().
transpose
(
0
,
1
)).
contiguous
()
0
,
1
).
clone
().
transpose
(
0
,
1
)).
contiguous
()
hidden_states
,
_B
,
_C
=
split_hidden_states_B_C_fn
(
hidden_states
,
_B
,
_C
=
split_hidden_states_B_C_fn
(
...
@@ -579,49 +553,27 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -579,49 +553,27 @@ class MambaMixer2(MambaBase, CustomOp):
has_decode
=
num_decodes
>
0
has_decode
=
num_decodes
>
0
num_actual_tokens
=
num_prefill_tokens
+
num_decodes
num_actual_tokens
=
num_prefill_tokens
+
num_decodes
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Separate prefill and decode by splitting varlen input
# Split along token dimension
# Split along token dimension
if
envs
.
VLLM_USE_V1
:
hidden_states_B_C_d
,
hidden_states_B_C_p
=
torch
.
split
(
hidden_states_B_C_d
,
hidden_states_B_C_p
=
torch
.
split
(
hidden_states_B_C
[:
num_actual_tokens
],
hidden_states_B_C
[:
num_actual_tokens
],
[
num_decodes
,
num_prefill_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
dim
=
0
,
)
)
dt_d
,
dt_p
=
torch
.
split
(
dt_d
,
dt_p
=
torch
.
split
(
dt
[:
num_actual_tokens
],
dt
[:
num_actual_tokens
],
[
num_decodes
,
num_prefill_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
dim
=
0
,
)
)
# Split along batch dimension
# Split along batch dimension
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor
[:
num_actual_tokens
],
state_indices_tensor
[:
num_actual_tokens
],
[
num_decodes
,
num_prefills
],
[
num_decodes
,
num_prefills
],
dim
=
0
,
dim
=
0
,
)
)
query_start_loc_p
=
(
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decodes
if
has_prefill
else
None
)
num_decodes
if
has_prefill
else
None
)
else
:
hidden_states_B_C_p
,
hidden_states_B_C_d
=
torch
.
split
(
hidden_states_B_C
,
[
num_prefill_tokens
,
num_decodes
],
dim
=
0
,
)
dt_p
,
dt_d
=
torch
.
split
(
dt
,
[
num_prefill_tokens
,
num_decodes
],
dim
=
0
,
)
# Split along batch dimension
state_indices_tensor_p
,
state_indices_tensor_d
=
torch
.
split
(
state_indices_tensor
,
[
num_prefills
,
num_decodes
],
dim
=
0
,
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[:
num_prefills
+
1
]
if
has_prefill
else
None
)
# Preallocate output tensor to avoid memcpy cost for merging prefill
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
# and decode outputs
...
@@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp):
dtype
=
hidden_states
.
dtype
,
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
)
)
if
envs
.
VLLM_USE_V1
:
preallocated_ssm_out_d
,
preallocated_ssm_out_p
=
torch
.
split
(
preallocated_ssm_out_d
,
preallocated_ssm_out_p
=
torch
.
split
(
preallocated_ssm_out
,
preallocated_ssm_out
,
[
num_decodes
,
num_prefill_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
dim
=
0
,
)
)
else
:
preallocated_ssm_out_p
,
preallocated_ssm_out_d
=
torch
.
split
(
preallocated_ssm_out
,
[
num_prefill_tokens
,
num_decodes
],
dim
=
0
,
)
# Process prefill requests
# Process prefill requests
if
has_prefill
:
if
has_prefill
:
...
@@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp):
# pointed to by "state_indices_tensor"
# pointed to by "state_indices_tensor"
x
=
hidden_states_B_C_p
.
transpose
(
x
=
hidden_states_B_C_p
.
transpose
(
0
,
1
)
# this is the form that causal-conv see
0
,
1
)
# this is the form that causal-conv see
if
mamba2_metadata
.
cu_seqlen
is
None
:
mamba2_metadata
=
update_metadata
(
x
,
query_start_loc_p
,
mamba2_metadata
)
hidden_states_B_C_p
=
causal_conv1d_fn
(
hidden_states_B_C_p
=
causal_conv1d_fn
(
x
,
x
,
conv_weights
,
conv_weights
,
...
@@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp):
conv_states
=
conv_state
,
conv_states
=
conv_state
,
has_initial_state
=
has_initial_states_p
,
has_initial_state
=
has_initial_states_p
,
cache_indices
=
state_indices_tensor_p
,
cache_indices
=
state_indices_tensor_p
,
metadata
=
mamba2
_metadata
,
metadata
=
attn
_metadata
,
query_start_loc
=
query_start_loc_p
).
transpose
(
query_start_loc
=
query_start_loc_p
).
transpose
(
0
,
1
)[:
num_prefill_tokens
]
0
,
1
)[:
num_prefill_tokens
]
...
@@ -806,8 +748,6 @@ def mamba_mixer2(
...
@@ -806,8 +748,6 @@ def mamba_mixer2(
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
output
=
output
,
output
=
output
,
mamba_cache_params
=
None
,
mamba2_metadata
=
None
,
mup_vector
=
mup_vector
)
mup_vector
=
mup_vector
)
...
...
vllm/model_executor/layers/mamba/mamba_utils.py
View file @
a903669e
...
@@ -100,7 +100,6 @@ class MambaStateShapeCalculator:
...
@@ -100,7 +100,6 @@ class MambaStateShapeCalculator:
intermediate_size
:
int
,
intermediate_size
:
int
,
state_size
:
int
,
state_size
:
int
,
conv_kernel
:
int
,
conv_kernel
:
int
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
]]:
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
]]:
conv_state_shape
=
(
divide
(
intermediate_size
,
conv_state_shape
=
(
divide
(
intermediate_size
,
tp_world_size
),
conv_kernel
-
1
)
tp_world_size
),
conv_kernel
-
1
)
...
@@ -108,11 +107,7 @@ class MambaStateShapeCalculator:
...
@@ -108,11 +107,7 @@ class MambaStateShapeCalculator:
temporal_state_shape
=
(
divide
(
intermediate_size
,
temporal_state_shape
=
(
divide
(
intermediate_size
,
tp_world_size
),
state_size
)
tp_world_size
),
state_size
)
# In V0, the conv_state shape was swapped during allocation in
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if
use_v1
:
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
return
conv_state_shape
,
temporal_state_shape
return
conv_state_shape
,
temporal_state_shape
...
@@ -126,7 +121,6 @@ class MambaStateShapeCalculator:
...
@@ -126,7 +121,6 @@ class MambaStateShapeCalculator:
head_dim
:
int
,
head_dim
:
int
,
state_size
:
int
,
state_size
:
int
,
conv_kernel
:
int
,
conv_kernel
:
int
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
# if n_groups is not divisible by world_size, need to extend the shards
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
# to ensure all groups needed by a head is sharded along with it
...
@@ -137,8 +131,6 @@ class MambaStateShapeCalculator:
...
@@ -137,8 +131,6 @@ class MambaStateShapeCalculator:
# contiguous along 'dim' axis
# contiguous along 'dim' axis
conv_state_shape
=
(
conv_kernel
-
1
,
divide
(
conv_dim
,
tp_world_size
))
conv_state_shape
=
(
conv_kernel
-
1
,
divide
(
conv_dim
,
tp_world_size
))
if
not
use_v1
:
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
# These are not TP-ed as they depend on A, dt_bias, D
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# - they are typically small
...
@@ -153,12 +145,9 @@ class MambaStateShapeCalculator:
...
@@ -153,12 +145,9 @@ class MambaStateShapeCalculator:
tp_world_size
:
int
,
tp_world_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
conv_kernel
:
int
,
conv_kernel
:
int
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
]]:
)
->
tuple
[
tuple
[
int
,
int
]]:
conv_dim
=
divide
(
intermediate_size
,
tp_world_size
)
conv_dim
=
divide
(
intermediate_size
,
tp_world_size
)
conv_state_shape
=
(
conv_kernel
-
1
,
conv_dim
)
conv_state_shape
=
(
conv_kernel
-
1
,
conv_dim
)
if
not
use_v1
:
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
return
(
conv_state_shape
,
)
return
(
conv_state_shape
,
)
@
classmethod
@
classmethod
...
@@ -183,7 +172,6 @@ class MambaStateShapeCalculator:
...
@@ -183,7 +172,6 @@ class MambaStateShapeCalculator:
head_v_dim
:
int
,
head_v_dim
:
int
,
conv_kernel_size
:
int
,
conv_kernel_size
:
int
,
num_spec
:
int
=
0
,
num_spec
:
int
=
0
,
use_v1
:
bool
=
True
,
):
):
conv_dim
=
(
head_k_dim
*
num_k_heads
*
2
+
head_v_dim
*
num_v_heads
)
conv_dim
=
(
head_k_dim
*
num_k_heads
*
2
+
head_v_dim
*
num_v_heads
)
conv_state_shape
=
(
conv_state_shape
=
(
...
@@ -191,11 +179,7 @@ class MambaStateShapeCalculator:
...
@@ -191,11 +179,7 @@ class MambaStateShapeCalculator:
conv_kernel_size
-
1
+
num_spec
,
conv_kernel_size
-
1
+
num_spec
,
)
)
# In V0, the conv_state shape was swapped during allocation in
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if
use_v1
:
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
temporal_state_shape
=
(
divide
(
num_v_heads
,
temporal_state_shape
=
(
divide
(
num_v_heads
,
tp_world_size
),
head_k_dim
,
head_v_dim
)
tp_world_size
),
head_k_dim
,
head_v_dim
)
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
a903669e
...
@@ -420,9 +420,7 @@ def causal_conv1d_fn(
...
@@ -420,9 +420,7 @@ def causal_conv1d_fn(
x
=
x
.
to
(
conv_states
.
dtype
)
x
=
x
.
to
(
conv_states
.
dtype
)
out
=
torch
.
empty_like
(
x
)
out
=
torch
.
empty_like
(
x
)
if
metadata
is
not
None
:
if
metadata
is
not
None
:
cu_seqlen
=
metadata
.
cu_seqlen
nums_dict
=
metadata
.
nums_dict
nums_dict
=
metadata
.
nums_dict
#x = metadata.x
args
=
nums_dict
args
=
nums_dict
batch_ptr
=
metadata
.
batch_ptr
batch_ptr
=
metadata
.
batch_ptr
token_chunk_offset_ptr
=
metadata
.
token_chunk_offset_ptr
token_chunk_offset_ptr
=
metadata
.
token_chunk_offset_ptr
...
@@ -926,7 +924,6 @@ def causal_conv1d_update(
...
@@ -926,7 +924,6 @@ def causal_conv1d_update(
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
max_query_len
:
int
=
-
1
,
max_query_len
:
int
=
-
1
,
pad_slot_id
:
int
=
PAD_SLOT_ID
,
pad_slot_id
:
int
=
PAD_SLOT_ID
,
metadata
=
None
,
validate_data
=
False
,
validate_data
=
False
,
):
):
"""
"""
...
...
vllm/model_executor/layers/mamba/short_conv.py
View file @
a903669e
...
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
...
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
import
torch
import
torch
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
...
@@ -18,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -18,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
update_metadata
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
...
@@ -71,15 +69,11 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -71,15 +69,11 @@ class ShortConv(MambaBase, CustomOp):
prefix
=
f
"
{
prefix
}
.out_proj"
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
)
assert
envs
.
VLLM_USE_V1
,
(
"ShortConv layers are only supported in V1"
)
compilation_config
=
get_current_vllm_config
().
compilation_config
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
# The outer list is for v0 PP virtual engine. Though this code path
self
.
kv_cache
=
(
torch
.
tensor
([]),
)
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
self
.
kv_cache
=
[(
torch
.
tensor
([]),
)]
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
...
@@ -89,7 +83,6 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -89,7 +83,6 @@ class ShortConv(MambaBase, CustomOp):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
conv_metadata
:
ShortConvAttentionMetadata
,
):
):
return
return
...
@@ -97,7 +90,6 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -97,7 +90,6 @@ class ShortConv(MambaBase, CustomOp):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
conv_metadata
:
ShortConvAttentionMetadata
,
):
):
torch
.
ops
.
vllm
.
short_conv
(
torch
.
ops
.
vllm
.
short_conv
(
hidden_states
,
hidden_states
,
...
@@ -109,7 +101,6 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -109,7 +101,6 @@ class ShortConv(MambaBase, CustomOp):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
conv_metadata
:
ShortConvAttentionMetadata
,
):
):
forward_context
=
get_forward_context
()
forward_context
=
get_forward_context
()
# ShortConvAttentionMetadata contains metadata necessary for the
# ShortConvAttentionMetadata contains metadata necessary for the
...
@@ -121,7 +112,6 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -121,7 +112,6 @@ class ShortConv(MambaBase, CustomOp):
if
attn_metadata
is
not
None
:
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
attn_metadata
=
attn_metadata
[
self
.
prefix
]
conv_metadata
=
attn_metadata
assert
isinstance
(
attn_metadata
,
ShortConvAttentionMetadata
)
assert
isinstance
(
attn_metadata
,
ShortConvAttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
...
@@ -181,9 +171,6 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -181,9 +171,6 @@ class ShortConv(MambaBase, CustomOp):
if
has_prefill
:
if
has_prefill
:
Bx_p
=
(
B_p
*
x_p
).
transpose
(
0
,
1
)
Bx_p
=
(
B_p
*
x_p
).
transpose
(
0
,
1
)
if
conv_metadata
.
cu_seqlen
is
None
:
conv_metadata
=
update_metadata
(
Bx_p
,
query_start_loc_p
,
conv_metadata
)
Bx
=
causal_conv1d_fn
(
Bx_p
,
Bx
=
causal_conv1d_fn
(
Bx_p
,
conv_weights
,
conv_weights
,
self
.
conv
.
bias
,
self
.
conv
.
bias
,
...
@@ -191,7 +178,7 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -191,7 +178,7 @@ class ShortConv(MambaBase, CustomOp):
conv_states
=
conv_state
,
conv_states
=
conv_state
,
has_initial_state
=
has_initial_states_p
,
has_initial_state
=
has_initial_states_p
,
cache_indices
=
state_indices_tensor_p
,
cache_indices
=
state_indices_tensor_p
,
metadata
=
conv
_metadata
,
metadata
=
attn
_metadata
,
query_start_loc
=
query_start_loc_p
).
transpose
(
query_start_loc
=
query_start_loc_p
).
transpose
(
0
,
1
)[:
num_prefill_tokens
]
0
,
1
)[:
num_prefill_tokens
]
...
@@ -248,9 +235,7 @@ def short_conv(
...
@@ -248,9 +235,7 @@ def short_conv(
)
->
None
:
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
output
=
output
)
output
=
output
,
conv_metadata
=
None
)
def
short_conv_fake
(
def
short_conv_fake
(
...
...
vllm/model_executor/models/bamba.py
View file @
a903669e
...
@@ -9,21 +9,17 @@ import torch
...
@@ -9,21 +9,17 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
BambaConfig
from
transformers
import
BambaConfig
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
...
@@ -32,10 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -32,10 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
SupportsQuant
)
SupportsQuant
)
...
@@ -115,8 +108,6 @@ class BambaMixerDecoderLayer(nn.Module):
...
@@ -115,8 +108,6 @@ class BambaMixerDecoderLayer(nn.Module):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
**
kwargs
,
):
):
if
residual
is
None
:
if
residual
is
None
:
...
@@ -127,7 +118,7 @@ class BambaMixerDecoderLayer(nn.Module):
...
@@ -127,7 +118,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states
,
residual
)
hidden_states
,
residual
)
output
=
torch
.
empty_like
(
hidden_states
)
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mamba
(
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
)
self
.
mamba
(
hidden_states
,
output
)
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
output
,
residual
)
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
output
,
residual
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
...
@@ -315,22 +306,10 @@ class BambaModel(nn.Module):
...
@@ -315,22 +306,10 @@ class BambaModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
attn_metadata
=
get_forward_context
().
attn_metadata
if
not
envs
.
VLLM_USE_V1
:
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
mamba_chunk_size
,
attn_metadata
=
attn_metadata
,
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
...
@@ -343,23 +322,11 @@ class BambaModel(nn.Module):
...
@@ -343,23 +322,11 @@ class BambaModel(nn.Module):
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
None
residual
=
None
num_attn
=
0
for
i
,
layer
in
enumerate
(
self
.
layers
):
for
i
,
layer
in
enumerate
(
self
.
layers
):
if
isinstance
(
layer
,
BambaAttentionDecoderLayer
):
num_attn
+=
1
layer_mamba_cache_params
=
None
if
isinstance
(
layer
,
BambaMixerDecoderLayer
)
and
mamba_cache_params
:
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
num_attn
)
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
residual
=
residual
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
...
@@ -457,13 +424,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -457,13 +424,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
get_mamba_state_shape_from_config
(
def
get_mamba_state_shape_from_config
(
cls
,
cls
,
vllm_config
:
"VllmConfig"
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
"""Calculate shapes for Mamba's convolutional and state caches.
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
Args:
vllm_config: vLLM config
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Returns:
Tuple containing:
Tuple containing:
...
@@ -482,7 +447,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -482,7 +447,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim
=
hf_config
.
mamba_d_head
,
head_dim
=
hf_config
.
mamba_d_head
,
state_size
=
hf_config
.
mamba_d_state
,
state_size
=
hf_config
.
mamba_d_state
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
use_v1
=
use_v1
,
)
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
@@ -515,8 +479,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -515,8 +479,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
config
.
vocab_size
)
...
@@ -534,39 +496,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -534,39 +496,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
mamba_cache_params
=
None
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
if
not
envs
.
VLLM_USE_V1
:
inputs_embeds
)
if
self
.
mamba_cache
is
None
:
num_mamba_layers
=
\
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
mamba_state_shape
=
\
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
,
use_v1
=
False
)
mamba_state_dtype
=
\
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_mamba_layers
,
*
mamba_state_shape
,
*
mamba_state_dtype
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
return
self
.
mamba_cache
.
copy_inputs_before_cuda_graphs
(
input_buffers
,
**
kwargs
)
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
return
self
.
mamba_cache
.
get_seqlen_agnostic_capture_inputs
(
batch_size
)
def
compute_logits
(
def
compute_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/constant_size_cache.py
deleted
100644 → 0
View file @
2c58742d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
import
torch
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
class
ConstantSizeCache
(
ABC
):
"""
Abstract base class for managing constant size caches
like Mamba and Minimax.
"""
def
__init__
(
self
,
max_batch_size
:
int
):
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the cache
self
.
cache_indices_mapping
:
dict
[
str
,
dict
[
int
,
int
]]
=
{}
self
.
free_cache_indices
=
list
(
range
(
max_batch_size
))
@
property
@
abstractmethod
def
cache
(
self
)
->
Any
:
"""Return the underlying cache tensor(s)"""
pass
@
abstractmethod
def
_copy_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
"""Copy cache data from one index to another"""
pass
def
current_run_tensors
(
self
,
**
kwargs
)
->
tuple
:
"""
Return the tensors for the current run's conv and ssm state.
"""
if
"seqlen_agnostic_capture_inputs"
not
in
kwargs
:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
self
.
_release_finished_requests
(
finished_requests_ids
)
state_indices
=
self
.
_prepare_current_run_cache
(
request_ids_to_seq_ids
,
finished_requests_ids
)
state_indices_tensor
=
torch
.
as_tensor
(
state_indices
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
cache_tensors
=
self
.
cache
else
:
# CUDA graph capturing runs
cache_tensors
,
state_indices_tensor
=
kwargs
[
"seqlen_agnostic_capture_inputs"
]
return
(
cache_tensors
,
state_indices_tensor
)
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert
all
(
key
in
kwargs
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
assert
"seqlen_agnostic_capture_inputs"
in
input_buffers
_
,
input_state_indices_buffer
=
input_buffers
[
"seqlen_agnostic_capture_inputs"
]
self
.
_release_finished_requests
(
finished_requests_ids
)
state_indices
=
self
.
_prepare_current_run_cache
(
request_ids_to_seq_ids
,
finished_requests_ids
)
cuda_graph_pad_len
=
input_state_indices_buffer
.
shape
[
0
]
-
len
(
state_indices
)
state_indices
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_len
)
input_state_indices_buffer
.
copy_
(
torch
.
as_tensor
(
state_indices
,
dtype
=
torch
.
int32
,
device
=
"cuda"
))
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Cache during the CUDA graph replay
runs.
"""
state_indices_tensor
=
torch
.
as_tensor
([
PAD_SLOT_ID
]
*
batch_size
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
return
(
self
.
cache
,
state_indices_tensor
)
def
_assign_seq_id_to_cache_index
(
self
,
cur_rid
:
str
,
seq_id
:
int
,
finished_requests_ids
)
->
int
:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if
cur_rid
in
finished_requests_ids
:
# set as pad, do not allocate destination index
return
PAD_SLOT_ID
elif
cur_rid
not
in
self
.
cache_indices_mapping
:
destination_index
=
self
.
free_cache_indices
.
pop
()
self
.
cache_indices_mapping
[
cur_rid
]
=
{
seq_id
:
destination_index
}
return
destination_index
elif
seq_id
not
in
(
seq_ids2indices
:
=
self
.
cache_indices_mapping
[
cur_rid
]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists
=
next
(
iter
(
seq_ids2indices
.
values
()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index
=
self
.
free_cache_indices
.
pop
()
self
.
_copy_cache
(
from_index
=
index_exists
,
to_index
=
destination_index
)
self
.
cache_indices_mapping
[
cur_rid
][
seq_id
]
=
destination_index
return
destination_index
else
:
return
self
.
cache_indices_mapping
[
cur_rid
][
seq_id
]
def
_prepare_current_run_cache
(
self
,
request_ids_to_seq_ids
:
dict
[
str
,
list
[
int
]],
finished_requests_ids
:
list
[
str
])
->
list
[
int
]:
return
[
self
.
_assign_seq_id_to_cache_index
(
req_id
,
seq_id
,
finished_requests_ids
)
for
req_id
,
seq_ids
in
request_ids_to_seq_ids
.
items
()
for
seq_id
in
seq_ids
]
def
_release_finished_requests
(
self
,
finished_seq_groups_req_ids
:
list
[
str
]):
for
req_id
in
finished_seq_groups_req_ids
:
if
req_id
in
self
.
cache_indices_mapping
:
for
seq_id
in
self
.
cache_indices_mapping
[
req_id
]:
self
.
free_cache_indices
.
append
(
self
.
cache_indices_mapping
[
req_id
][
seq_id
])
self
.
cache_indices_mapping
.
pop
(
req_id
)
vllm/model_executor/models/falcon_h1.py
View file @
a903669e
...
@@ -8,21 +8,17 @@ import torch
...
@@ -8,21 +8,17 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
FalconH1Config
from
transformers
import
FalconH1Config
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
...
@@ -31,8 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -31,8 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
...
@@ -179,16 +173,12 @@ class FalconH1SSMDecoderLayer(nn.Module):
...
@@ -179,16 +173,12 @@ class FalconH1SSMDecoderLayer(nn.Module):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
**
kwargs
,
):
):
output
=
torch
.
empty_like
(
hidden_states
)
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mamba
(
self
.
mamba
(
hidden_states
,
hidden_states
,
output
,
output
,
mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
mup_vector
=
self
.
mup_vector
,
mup_vector
=
self
.
mup_vector
,
)
)
return
output
,
residual
return
output
,
residual
...
@@ -364,8 +354,6 @@ class FalconH1ParallelHybrid(nn.Module):
...
@@ -364,8 +354,6 @@ class FalconH1ParallelHybrid(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
**
kwargs
,
):
):
residual
=
hidden_states
residual
=
hidden_states
...
@@ -382,12 +370,10 @@ class FalconH1ParallelHybrid(nn.Module):
...
@@ -382,12 +370,10 @@ class FalconH1ParallelHybrid(nn.Module):
# Process input through the SSM branch.
# Process input through the SSM branch.
# FalconH1SSMDecoderLayer expects hidden_states, attn_metadata,
# FalconH1SSMDecoderLayer expects hidden_states, attn_metadata,
# residual,
mamba_cache_params,
and sequence_idx.
# residual, and sequence_idx.
ssm_hidden
,
_
=
self
.
mamba
(
ssm_hidden
,
_
=
self
.
mamba
(
hidden_states
=
hidden_states
*
self
.
ssm_in_multiplier
,
hidden_states
=
hidden_states
*
self
.
ssm_in_multiplier
,
residual
=
residual
,
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
**
kwargs
,
**
kwargs
,
)
)
# Sum the outputs from both branches.
# Sum the outputs from both branches.
...
@@ -464,25 +450,10 @@ class FalconH1Model(nn.Module):
...
@@ -464,25 +450,10 @@ class FalconH1Model(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
attn_metadata
=
get_forward_context
().
attn_metadata
if
not
envs
.
VLLM_USE_V1
:
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
mamba_chunk_size
,
attn_metadata
=
attn_metadata
,
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
*
self
.
embedding_multiplier
hidden_states
=
inputs_embeds
*
self
.
embedding_multiplier
...
@@ -495,14 +466,9 @@ class FalconH1Model(nn.Module):
...
@@ -495,14 +466,9 @@ class FalconH1Model(nn.Module):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
layer_mamba_cache_params
=
None
if
mamba_cache_params
:
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
)
hidden_states
=
layer
(
hidden_states
=
layer
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
mamba_cache_params
=
layer_mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
...
@@ -541,13 +507,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -541,13 +507,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
get_mamba_state_shape_from_config
(
def
get_mamba_state_shape_from_config
(
cls
,
cls
,
vllm_config
:
"VllmConfig"
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
"""Calculate shapes for Mamba's convolutional and state caches.
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
Args:
vllm_config: vLLM config
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Returns:
Tuple containing:
Tuple containing:
...
@@ -570,7 +534,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -570,7 +534,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim
=
hf_config
.
mamba_d_head
,
head_dim
=
hf_config
.
mamba_d_head
,
state_size
=
hf_config
.
mamba_d_state
,
state_size
=
hf_config
.
mamba_d_state
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
use_v1
=
use_v1
,
)
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
@@ -592,7 +555,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -592,7 +555,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
tie_word_embeddings
=
config
.
tie_word_embeddings
self
.
tie_word_embeddings
=
config
.
tie_word_embeddings
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
if
lora_config
:
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
...
@@ -637,40 +599,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -637,40 +599,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
**
kwargs
,
**
kwargs
,
):
):
mamba_cache_params
=
None
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
mamba_state_shape
=
\
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
,
use_v1
=
False
)
mamba_state_dtype
=
\
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
self
.
config
.
num_hidden_layers
,
*
mamba_state_shape
,
*
mamba_state_dtype
,
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
,
input_ids
,
positions
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
,
inputs_embeds
,
)
)
return
hidden_states
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
return
self
.
mamba_cache
.
copy_inputs_before_cuda_graphs
(
input_buffers
,
**
kwargs
)
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
return
self
.
mamba_cache
.
get_seqlen_agnostic_capture_inputs
(
batch_size
)
def
compute_logits
(
def
compute_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/granitemoehybrid.py
View file @
a903669e
...
@@ -9,19 +9,15 @@ import torch
...
@@ -9,19 +9,15 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
GraniteMoeHybridConfig
from
transformers
import
GraniteMoeHybridConfig
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
...
@@ -30,10 +26,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -30,10 +26,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.granitemoe
import
GraniteMoeMoE
from
.granitemoe
import
GraniteMoeMoE
from
.granitemoeshared
import
GraniteMoeSharedMLP
from
.granitemoeshared
import
GraniteMoeSharedMLP
...
@@ -102,14 +95,12 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
...
@@ -102,14 +95,12 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
**
kwargs
,
):
):
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
output
=
torch
.
empty_like
(
hidden_states
)
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mamba
(
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
)
self
.
mamba
(
hidden_states
,
output
)
hidden_states
=
residual
+
output
*
self
.
residual_multiplier
hidden_states
=
residual
+
output
*
self
.
residual_multiplier
residual
=
hidden_states
residual
=
hidden_states
...
@@ -182,8 +173,6 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
...
@@ -182,8 +173,6 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
...
@@ -366,22 +355,10 @@ class GraniteMoeHybridModel(nn.Module):
...
@@ -366,22 +355,10 @@ class GraniteMoeHybridModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
attn_metadata
=
get_forward_context
().
attn_metadata
if
not
envs
.
VLLM_USE_V1
:
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
mamba_chunk_size
,
attn_metadata
=
attn_metadata
,
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
...
@@ -399,20 +376,9 @@ class GraniteMoeHybridModel(nn.Module):
...
@@ -399,20 +376,9 @@ class GraniteMoeHybridModel(nn.Module):
for
i
,
layer
in
enumerate
(
self
.
layers
):
for
i
,
layer
in
enumerate
(
self
.
layers
):
if
isinstance
(
layer
,
GraniteMoeHybridAttentionDecoderLayer
):
if
isinstance
(
layer
,
GraniteMoeHybridAttentionDecoderLayer
):
num_attn
+=
1
num_attn
+=
1
hidden_states
,
residual
=
layer
(
positions
=
positions
,
layer_mamba_cache_params
=
None
hidden_states
=
hidden_states
,
if
isinstance
(
residual
=
residual
)
layer
,
GraniteMoeHybridMambaDecoderLayer
)
and
mamba_cache_params
:
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
num_attn
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
...
@@ -552,13 +518,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
...
@@ -552,13 +518,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
def
get_mamba_state_shape_from_config
(
def
get_mamba_state_shape_from_config
(
cls
,
cls
,
vllm_config
:
"VllmConfig"
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
"""Calculate shapes for Mamba's convolutional and state caches.
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
Args:
vllm_config: vLLM config
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Returns:
Tuple containing:
Tuple containing:
...
@@ -577,7 +541,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
...
@@ -577,7 +541,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
head_dim
=
hf_config
.
mamba_d_head
,
head_dim
=
hf_config
.
mamba_d_head
,
state_size
=
hf_config
.
mamba_d_state
,
state_size
=
hf_config
.
mamba_d_state
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
use_v1
=
use_v1
,
)
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
@@ -620,9 +583,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
...
@@ -620,9 +583,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
scale
=
1
/
scale
=
1
/
self
.
config
.
logits_scaling
)
self
.
config
.
logits_scaling
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
...
@@ -636,38 +596,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
...
@@ -636,38 +596,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
mamba_cache_params
=
None
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
if
not
envs
.
VLLM_USE_V1
:
inputs_embeds
)
if
self
.
mamba_cache
is
None
:
num_mamba_layers
=
(
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
))
mamba_state_shape
=
\
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
,
use_v1
=
False
)
mamba_state_dtype
=
\
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_mamba_layers
,
*
mamba_state_shape
,
*
mamba_state_dtype
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
return
self
.
mamba_cache
.
copy_inputs_before_cuda_graphs
(
input_buffers
,
**
kwargs
)
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
return
self
.
mamba_cache
.
get_seqlen_agnostic_capture_inputs
(
batch_size
)
def
compute_logits
(
def
compute_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/jamba.py
View file @
a903669e
...
@@ -9,7 +9,6 @@ import torch
...
@@ -9,7 +9,6 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
JambaConfig
from
transformers
import
JambaConfig
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
...
@@ -30,10 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -30,10 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama
import
LlamaMLP
as
JambaMLP
from
vllm.model_executor.models.llama
import
LlamaMLP
as
JambaMLP
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
is_pp_missing_parameter
,
...
@@ -145,7 +141,6 @@ class JambaMambaDecoderLayer(nn.Module):
...
@@ -145,7 +141,6 @@ class JambaMambaDecoderLayer(nn.Module):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
**
kwargs
,
**
kwargs
,
):
):
if
residual
is
None
:
if
residual
is
None
:
...
@@ -156,7 +151,7 @@ class JambaMambaDecoderLayer(nn.Module):
...
@@ -156,7 +151,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states
,
residual
)
hidden_states
,
residual
)
output
=
torch
.
empty_like
(
hidden_states
)
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mamba
(
hidden_states
,
output
,
mamba_cache_params
)
self
.
mamba
(
hidden_states
,
output
)
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
output
,
residual
)
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
output
,
residual
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
...
@@ -333,7 +328,6 @@ class JambaModel(nn.Module):
...
@@ -333,7 +328,6 @@ class JambaModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -348,24 +342,11 @@ class JambaModel(nn.Module):
...
@@ -348,24 +342,11 @@ class JambaModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
kv_cache_index
=
0
mamba_cache_index
=
0
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
):
layer_mamba_cache_params
=
None
hidden_states
,
residual
=
layer
(
positions
=
positions
,
if
isinstance
(
layer
,
JambaAttentionDecoderLayer
):
hidden_states
=
hidden_states
,
kv_cache_index
+=
1
residual
=
residual
)
if
isinstance
(
layer
,
JambaMambaDecoderLayer
)
and
mamba_cache_params
:
current_state_layer
=
mamba_cache_index
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
current_state_layer
)
mamba_cache_index
+=
1
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"hidden_states"
:
hidden_states
,
...
@@ -503,8 +484,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -503,8 +484,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
config
.
vocab_size
)
...
@@ -521,24 +500,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -521,24 +500,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params
=
None
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
if
not
envs
.
VLLM_USE_V1
:
inputs_embeds
)
if
self
.
mamba_cache
is
None
:
num_layers
=
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
state_shape
=
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
)
state_dtype
=
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_layers
,
*
state_shape
,
*
state_dtype
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
...
@@ -574,7 +538,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -574,7 +538,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_size
=
hf_config
.
mamba_expand
*
hidden_size
,
intermediate_size
=
hf_config
.
mamba_expand
*
hidden_size
,
state_size
=
hf_config
.
mamba_d_state
,
state_size
=
hf_config
.
mamba_d_state
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
use_v1
=
envs
.
VLLM_USE_V1
,
)
)
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/lfm2.py
View file @
a903669e
...
@@ -8,7 +8,6 @@ import torch
...
@@ -8,7 +8,6 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
Lfm2Config
from
transformers
import
Lfm2Config
from
vllm
import
envs
from
vllm.attention
import
Attention
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
...
@@ -297,7 +296,6 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
...
@@ -297,7 +296,6 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
self
.
conv
(
self
.
conv
(
hidden_states
,
hidden_states
,
output
,
output
,
conv_metadata
=
None
,
)
)
hidden_states
,
residual
=
self
.
ffn_norm
(
output
,
residual
)
hidden_states
,
residual
=
self
.
ffn_norm
(
output
,
residual
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
...
@@ -459,13 +457,11 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -459,13 +457,11 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
get_mamba_state_shape_from_config
(
def
get_mamba_state_shape_from_config
(
cls
,
cls
,
vllm_config
:
"VllmConfig"
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
]]:
)
->
tuple
[
tuple
[
int
,
int
]]:
""" Calculate shapes for LFM2's convolutional cache.
""" Calculate shapes for LFM2's convolutional cache.
Args:
Args:
vllm_config: vLLM config
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Returns:
Tuple containing:
Tuple containing:
...
@@ -478,7 +474,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -478,7 +474,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
tp_world_size
=
parallel_config
.
tensor_parallel_size
,
tp_world_size
=
parallel_config
.
tensor_parallel_size
,
intermediate_size
=
hf_config
.
conv_dim
,
intermediate_size
=
hf_config
.
conv_dim
,
conv_kernel
=
hf_config
.
conv_L_cache
,
conv_kernel
=
hf_config
.
conv_L_cache
,
use_v1
=
use_v1
,
)
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
...
@@ -489,8 +484,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -489,8 +484,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
scheduler_config
=
vllm_config
.
scheduler_config
scheduler_config
=
vllm_config
.
scheduler_config
assert
(
not
cache_config
.
enable_prefix_caching
assert
(
not
cache_config
.
enable_prefix_caching
),
"Lfm2 currently does not support prefix caching"
),
"Lfm2 currently does not support prefix caching"
assert
envs
.
VLLM_USE_V1
,
(
"Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1"
)
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
...
vllm/model_executor/models/mamba.py
View file @
a903669e
...
@@ -8,7 +8,6 @@ import torch
...
@@ -8,7 +8,6 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
MambaConfig
from
transformers
import
MambaConfig
from
vllm
import
envs
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
...
@@ -24,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -24,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
,
SupportsPP
)
IsAttentionFree
,
SupportsPP
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
...
@@ -72,7 +68,6 @@ class MambaDecoderLayer(nn.Module):
...
@@ -72,7 +68,6 @@ class MambaDecoderLayer(nn.Module):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
**
kwargs
,
**
kwargs
,
):
):
if
residual
is
None
:
if
residual
is
None
:
...
@@ -82,7 +77,7 @@ class MambaDecoderLayer(nn.Module):
...
@@ -82,7 +77,7 @@ class MambaDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
output
=
torch
.
empty_like
(
hidden_states
)
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mixer
(
hidden_states
,
output
,
mamba_cache_params
)
self
.
mixer
(
hidden_states
,
output
)
return
output
,
residual
return
output
,
residual
...
@@ -134,7 +129,6 @@ class MambaModel(nn.Module):
...
@@ -134,7 +129,6 @@ class MambaModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
Optional
[
MambaCacheParams
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -151,17 +145,9 @@ class MambaModel(nn.Module):
...
@@ -151,17 +145,9 @@ class MambaModel(nn.Module):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
=
positions
,
layer_cache_params
=
None
hidden_states
=
hidden_states
,
if
mamba_cache_params
is
not
None
:
residual
=
residual
)
layer_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
self
.
start_layer
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
layer_cache_params
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"hidden_states"
:
hidden_states
,
...
@@ -225,9 +211,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
...
@@ -225,9 +211,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
config
.
vocab_size
)
...
@@ -244,22 +227,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
...
@@ -244,22 +227,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
mamba_cache_params
=
None
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
num_layers
=
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
state_shape
=
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
)
state_dtype
=
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_layers
,
*
state_shape
,
*
state_dtype
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
...
@@ -288,8 +256,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
...
@@ -288,8 +256,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
tp_world_size
=
parallel_config
.
tensor_parallel_size
,
tp_world_size
=
parallel_config
.
tensor_parallel_size
,
intermediate_size
=
hf_config
.
intermediate_size
,
intermediate_size
=
hf_config
.
intermediate_size
,
state_size
=
hf_config
.
state_size
,
state_size
=
hf_config
.
state_size
,
conv_kernel
=
hf_config
.
conv_kernel
,
conv_kernel
=
hf_config
.
conv_kernel
)
use_v1
=
envs
.
VLLM_USE_V1
)
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
return
self
.
mamba_cache
.
copy_inputs_before_cuda_graphs
(
return
self
.
mamba_cache
.
copy_inputs_before_cuda_graphs
(
...
...
vllm/model_executor/models/mamba2.py
View file @
a903669e
...
@@ -8,16 +8,11 @@ import torch
...
@@ -8,16 +8,11 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
MambaConfig
from
transformers
import
MambaConfig
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
...
@@ -28,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -28,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
)
IsAttentionFree
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
...
@@ -74,8 +66,6 @@ class Mamba2DecoderLayer(nn.Module):
...
@@ -74,8 +66,6 @@ class Mamba2DecoderLayer(nn.Module):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
**
kwargs
,
):
):
if
residual
is
None
:
if
residual
is
None
:
...
@@ -85,7 +75,7 @@ class Mamba2DecoderLayer(nn.Module):
...
@@ -85,7 +75,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
output
=
torch
.
empty_like
(
hidden_states
)
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mixer
(
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
)
self
.
mixer
(
hidden_states
,
output
)
return
output
,
residual
return
output
,
residual
...
@@ -137,7 +127,6 @@ class Mamba2Model(nn.Module):
...
@@ -137,7 +127,6 @@ class Mamba2Model(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -152,25 +141,10 @@ class Mamba2Model(nn.Module):
...
@@ -152,25 +141,10 @@ class Mamba2Model(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
if
not
envs
.
VLLM_USE_V1
:
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
chunk_size
,
attn_metadata
=
attn_metadata
,
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
for
i
,
layer
in
enumerate
(
self
.
layers
):
for
i
,
layer
in
enumerate
(
self
.
layers
):
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
residual
=
residual
)
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
self
.
start_layer
)
if
mamba_cache_params
else
None
,
mamba2_metadata
=
mamba2_metadata
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
...
@@ -222,13 +196,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -222,13 +196,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def
get_mamba_state_shape_from_config
(
def
get_mamba_state_shape_from_config
(
cls
,
cls
,
vllm_config
:
"VllmConfig"
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
"""Calculate shapes for Mamba's convolutional and state caches.
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
Args:
vllm_config: vLLM config
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Returns:
Tuple containing:
Tuple containing:
...
@@ -247,7 +219,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -247,7 +219,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
head_dim
=
hf_config
.
head_dim
,
head_dim
=
hf_config
.
head_dim
,
state_size
=
hf_config
.
state_size
,
state_size
=
hf_config
.
state_size
,
conv_kernel
=
hf_config
.
conv_kernel
,
conv_kernel
=
hf_config
.
conv_kernel
,
use_v1
=
use_v1
,
)
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
@@ -282,9 +253,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -282,9 +253,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
if
config
.
tie_word_embeddings
:
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
lm_head
.
tie_weights
(
self
.
backbone
.
embeddings
)
self
.
lm_head
=
self
.
lm_head
.
tie_weights
(
self
.
backbone
.
embeddings
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
config
.
vocab_size
)
...
@@ -300,29 +268,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -300,29 +268,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
num_mamba_layers
=
(
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
))
mamba_state_shape
=
\
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
,
use_v1
=
False
)
mamba_state_dtype
=
\
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_mamba_layers
,
*
mamba_state_shape
,
*
mamba_state_dtype
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
else
:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params
=
None
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
mamba_cache_params
,
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/mamba_cache.py
deleted
100644 → 0
View file @
2c58742d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
torch
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.config
import
VllmConfig
from
vllm.model_executor.models.constant_size_cache
import
ConstantSizeCache
@
dataclass
class
MambaCacheParams
:
conv_state
:
torch
.
Tensor
=
torch
.
Tensor
()
ssm_state
:
torch
.
Tensor
=
torch
.
Tensor
()
state_indices_tensor
:
torch
.
Tensor
=
torch
.
Tensor
()
def
at_layer_idx
(
self
,
layer_idx
):
return
MambaCacheParams
(
self
.
conv_state
[
layer_idx
],
self
.
ssm_state
[
layer_idx
],
self
.
state_indices_tensor
)
class
MambaCacheManager
(
ConstantSizeCache
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
num_mamba_layers
:
int
,
conv_state_shape
:
tuple
[
int
,
int
],
temporal_state_shape
:
tuple
[
int
,
int
],
conv_state_dtype
:
torch
.
dtype
,
temporal_state_dtype
:
torch
.
dtype
):
self
.
conv_state_dtype
=
conv_state_dtype
self
.
temporal_state_dtype
=
temporal_state_dtype
# Determine max batch size to set size of MambaCache
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
if
not
vllm_config
.
model_config
.
enforce_eager
:
max_batch_size
=
vllm_config
.
pad_for_cudagraph
(
max_batch_size
)
# Initialize parent class
super
().
__init__
(
max_batch_size
)
# assume conv_state = (dim, state_len)
assert
conv_state_shape
[
0
]
>
conv_state_shape
[
1
]
conv_state
=
torch
.
empty
(
size
=
(
num_mamba_layers
,
max_batch_size
)
+
(
conv_state_shape
[
1
],
conv_state_shape
[
0
]),
dtype
=
self
.
conv_state_dtype
,
device
=
"cuda"
).
transpose
(
-
1
,
-
2
)
temporal_state
=
torch
.
empty
(
size
=
(
num_mamba_layers
,
max_batch_size
)
+
temporal_state_shape
,
dtype
=
self
.
temporal_state_dtype
,
device
=
"cuda"
)
self
.
_mamba_cache
=
(
conv_state
,
temporal_state
)
@
property
def
cache
(
self
):
return
self
.
_mamba_cache
def
_copy_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
for
cache_t
in
self
.
cache
:
cache_t
[:,
to_index
].
copy_
(
cache_t
[:,
from_index
],
non_blocking
=
True
)
def
current_run_tensors
(
self
,
**
kwargs
)
->
MambaCacheParams
:
"""
Return the tensors for the current run's conv and ssm state.
"""
cache_tensors
,
state_indices_tensor
=
super
().
current_run_tensors
(
**
kwargs
)
return
MambaCacheParams
(
cache_tensors
[
0
],
cache_tensors
[
1
],
state_indices_tensor
)
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return
self
.
_mamba_cache
,
torch
.
as_tensor
([
PAD_SLOT_ID
]
*
batch_size
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
vllm/model_executor/models/minimax_cache.py
deleted
100644 → 0
View file @
2c58742d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
torch
from
vllm.model_executor.models.constant_size_cache
import
ConstantSizeCache
@
dataclass
class
MinimaxCacheParams
:
minimax_cache
:
torch
.
Tensor
=
torch
.
Tensor
()
state_indices_tensor
:
torch
.
Tensor
=
torch
.
Tensor
()
def
at_layer_idx
(
self
,
layer_idx
):
return
MinimaxCacheParams
(
self
.
minimax_cache
[
layer_idx
,
...],
self
.
state_indices_tensor
)
class
MinimaxCacheManager
(
ConstantSizeCache
):
def
__init__
(
self
,
dtype
,
cache_shape
):
super
().
__init__
(
cache_shape
[
1
])
# max_batch_size is cache_shape[1]
self
.
_minimax_cache
=
torch
.
empty
(
size
=
cache_shape
,
dtype
=
dtype
,
device
=
"cuda"
)
@
property
def
cache
(
self
):
return
self
.
_minimax_cache
def
_copy_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
assert
len
(
self
.
cache
)
>
0
for
cache_t
in
self
.
cache
:
cache_t
[:,
to_index
].
copy_
(
cache_t
[:,
from_index
],
non_blocking
=
True
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment