Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
881e3cbe
Unverified
Commit
881e3cbe
authored
Jul 19, 2025
by
Thomas Parnell
Committed by
GitHub
Jul 19, 2025
Browse files
[V1] [Hybrid] Enable piecewise CUDA Graph for mamba layers (#21194)
Signed-off-by:
Thomas Parnell
<
tpa@zurich.ibm.com
>
parent
9f414a12
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
100 additions
and
31 deletions
+100
-31
tests/models/language/generation/test_hybrid.py
tests/models/language/generation/test_hybrid.py
+0
-1
vllm/config.py
vllm/config.py
+1
-0
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+66
-9
vllm/model_executor/models/bamba.py
vllm/model_executor/models/bamba.py
+6
-5
vllm/model_executor/models/falcon_h1.py
vllm/model_executor/models/falcon_h1.py
+6
-2
vllm/model_executor/models/granitemoehybrid.py
vllm/model_executor/models/granitemoehybrid.py
+5
-3
vllm/model_executor/models/mamba2.py
vllm/model_executor/models/mamba2.py
+5
-3
vllm/model_executor/models/nemotron_h.py
vllm/model_executor/models/nemotron_h.py
+5
-3
vllm/model_executor/models/zamba2.py
vllm/model_executor/models/zamba2.py
+6
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+0
-3
No files found.
tests/models/language/generation/test_hybrid.py
View file @
881e3cbe
...
@@ -104,7 +104,6 @@ def test_models(
...
@@ -104,7 +104,6 @@ def test_models(
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLASHINFER"
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLASHINFER"
)
with
vllm_runner
(
model
,
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
,
max_num_seqs
=
MAX_NUM_SEQS
,
enforce_eager
=
True
,
enable_prefix_caching
=
False
)
as
vllm_model
:
enable_prefix_caching
=
False
)
as
vllm_model
:
vllm_v1_outputs
=
vllm_model
.
generate_greedy_logprobs
(
vllm_v1_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
...
...
vllm/config.py
View file @
881e3cbe
...
@@ -4312,6 +4312,7 @@ class CompilationConfig:
...
@@ -4312,6 +4312,7 @@ class CompilationConfig:
self
.
splitting_ops
=
[]
if
self
.
full_cuda_graph
else
[
self
.
splitting_ops
=
[]
if
self
.
full_cuda_graph
else
[
"vllm.unified_attention"
,
"vllm.unified_attention"
,
"vllm.unified_attention_with_output"
,
"vllm.unified_attention_with_output"
,
"vllm.mamba_mixer2"
,
]
]
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
881e3cbe
...
@@ -13,7 +13,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -13,7 +13,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
...
@@ -33,6 +33,8 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -33,6 +33,8 @@ from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction
,
composed_weight_loader
,
sharded_weight_loader
)
LoaderFunction
,
composed_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.mamba_attn
import
Mamba2AttentionMetadata
from
vllm.v1.attention.backends.mamba_attn
import
Mamba2AttentionMetadata
# Added by the IBM Team, 2024
# Added by the IBM Team, 2024
...
@@ -424,14 +426,36 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -424,14 +426,36 @@ class MambaMixer2(MambaBase, CustomOp):
def
forward_native
(
def
forward_native
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
pass
pass
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
):
if
not
envs
.
VLLM_USE_V1
:
CustomOp
.
forward
(
self
,
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
,
mup_vector
)
else
:
torch
.
ops
.
vllm
.
mamba_mixer2
(
hidden_states
,
output
,
self
.
prefix
,
mup_vector
,
)
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
mamba2_metadata
:
Mamba2Metadata
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -517,6 +541,7 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -517,6 +541,7 @@ class MambaMixer2(MambaBase, CustomOp):
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
# token count
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
# token count
has_prefill
=
num_prefills
>
0
has_prefill
=
num_prefills
>
0
has_decode
=
num_decodes
>
0
has_decode
=
num_decodes
>
0
num_actual_tokens
=
num_prefill_tokens
+
num_decodes
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Separate prefill and decode by splitting varlen input
...
@@ -524,18 +549,18 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -524,18 +549,18 @@ class MambaMixer2(MambaBase, CustomOp):
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
if
envs
.
VLLM_USE_V1
:
if
envs
.
VLLM_USE_V1
:
hidden_states_B_C_d
,
hidden_states_B_C_p
=
torch
.
split
(
hidden_states_B_C_d
,
hidden_states_B_C_p
=
torch
.
split
(
hidden_states_B_C
,
hidden_states_B_C
[:
num_actual_tokens
]
,
[
num_decodes
,
num_prefill_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
dim
=
0
,
)
)
dt_d
,
dt_p
=
torch
.
split
(
dt_d
,
dt_p
=
torch
.
split
(
dt
,
dt
[:
num_actual_tokens
]
,
[
num_decodes
,
num_prefill_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
dim
=
0
,
)
)
# Split along batch dimension
# Split along batch dimension
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor
,
state_indices_tensor
[:
num_actual_tokens
]
,
[
num_decodes
,
num_prefills
],
[
num_decodes
,
num_prefills
],
dim
=
0
,
dim
=
0
,
)
)
...
@@ -696,11 +721,10 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -696,11 +721,10 @@ class MambaMixer2(MambaBase, CustomOp):
# GatedRMSNorm internally applying SiLU to the gate
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# SiLU is applied internally before normalization, unlike standard
# norm usage
# norm usage
hidden_states
=
self
.
norm
(
hidden_states
,
gate
)
hidden_states
=
self
.
norm
(
hidden_states
,
gate
[:
num_actual_tokens
]
)
# 5. Final linear projection
# 5. Final linear projection
out
,
_
=
self
.
out_proj
(
hidden_states
)
output
[:
num_actual_tokens
],
_
=
self
.
out_proj
(
hidden_states
)
return
out
def
get_state_shape
(
self
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...]]:
def
get_state_shape
(
self
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...]]:
return
get_mamba_state_shape
(
return
get_mamba_state_shape
(
...
@@ -712,3 +736,36 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -712,3 +736,36 @@ class MambaMixer2(MambaBase, CustomOp):
state_size
=
self
.
ssm_state_size
,
state_size
=
self
.
ssm_state_size
,
conv_kernel
=
self
.
conv_kernel_size
,
conv_kernel
=
self
.
conv_kernel_size
,
)
)
def
mamba_mixer2
(
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
output
=
output
,
mamba_cache_params
=
None
,
mamba2_metadata
=
None
,
mup_vector
=
mup_vector
)
def
mamba_mixer2_fake
(
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
return
direct_register_custom_op
(
op_name
=
"mamba_mixer2"
,
op_func
=
mamba_mixer2
,
mutates_args
=
[
"output"
],
fake_impl
=
mamba_mixer2_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
vllm/model_executor/models/bamba.py
View file @
881e3cbe
...
@@ -11,6 +11,7 @@ from transformers import BambaConfig
...
@@ -11,6 +11,7 @@ from transformers import BambaConfig
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
...
@@ -122,11 +123,10 @@ class BambaMixerDecoderLayer(nn.Module):
...
@@ -122,11 +123,10 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
hidden_states
=
self
.
mamba
(
hidden_states
,
mamba_cache_params
,
output
=
torch
.
empty_like
(
hidden_states
)
mamba2_metadata
)
self
.
mamba
(
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
)
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
output
,
residual
)
hidden_states
,
residual
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
return
hidden_states
,
residual
return
hidden_states
,
residual
...
@@ -169,7 +169,7 @@ class BambaAttentionDecoderLayer(nn.Module):
...
@@ -169,7 +169,7 @@ class BambaAttentionDecoderLayer(nn.Module):
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
if
hasattr
(
config
,
"partial_rotary_factor"
):
if
hasattr
(
config
,
"partial_rotary_factor"
):
rotary_dim
=
self
.
head_dim
*
config
.
partial_rotary_factor
rotary_dim
=
int
(
self
.
head_dim
*
config
.
partial_rotary_factor
)
elif
hasattr
(
config
,
"attn_rotary_emb"
):
elif
hasattr
(
config
,
"attn_rotary_emb"
):
rotary_dim
=
config
.
attn_rotary_emb
# for backward compatibility
rotary_dim
=
config
.
attn_rotary_emb
# for backward compatibility
else
:
else
:
...
@@ -258,6 +258,7 @@ ALL_DECODER_LAYER_TYPES = {
...
@@ -258,6 +258,7 @@ ALL_DECODER_LAYER_TYPES = {
}
}
@
support_torch_compile
class
BambaModel
(
nn
.
Module
):
class
BambaModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/model_executor/models/falcon_h1.py
View file @
881e3cbe
...
@@ -10,6 +10,7 @@ from transformers import FalconH1Config
...
@@ -10,6 +10,7 @@ from transformers import FalconH1Config
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
...
@@ -179,13 +180,15 @@ class FalconH1SSMDecoderLayer(nn.Module):
...
@@ -179,13 +180,15 @@ class FalconH1SSMDecoderLayer(nn.Module):
mamba2_metadata
:
Mamba2Metadata
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
**
kwargs
,
):
):
hidden_states
=
self
.
mamba
(
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mamba
(
hidden_states
,
hidden_states
,
output
,
mamba_cache_params
,
mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
mamba2_metadata
=
mamba2_metadata
,
mup_vector
=
self
.
mup_vector
,
mup_vector
=
self
.
mup_vector
,
)
)
return
hidden_states
,
residual
return
output
,
residual
class
FalconH1AttentionDecoderLayer
(
nn
.
Module
):
class
FalconH1AttentionDecoderLayer
(
nn
.
Module
):
...
@@ -398,6 +401,7 @@ class FalconH1ParallelHybrid(nn.Module):
...
@@ -398,6 +401,7 @@ class FalconH1ParallelHybrid(nn.Module):
return
hidden_states
return
hidden_states
@
support_torch_compile
class
FalconH1Model
(
nn
.
Module
):
class
FalconH1Model
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/model_executor/models/granitemoehybrid.py
View file @
881e3cbe
...
@@ -11,6 +11,7 @@ from transformers import GraniteMoeHybridConfig
...
@@ -11,6 +11,7 @@ from transformers import GraniteMoeHybridConfig
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
...
@@ -104,9 +105,9 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
...
@@ -104,9 +105,9 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
):
):
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
mamba
(
hidden_states
,
mamba_cache_params
,
output
=
torch
.
empty_like
(
hidden_states
)
mamba2_metadata
)
self
.
mamba
(
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
)
hidden_states
=
residual
+
hidden_states
*
self
.
residual_multiplier
hidden_states
=
residual
+
output
*
self
.
residual_multiplier
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
...
@@ -307,6 +308,7 @@ ALL_DECODER_LAYER_TYPES = {
...
@@ -307,6 +308,7 @@ ALL_DECODER_LAYER_TYPES = {
}
}
@
support_torch_compile
class
GraniteMoeHybridModel
(
nn
.
Module
):
class
GraniteMoeHybridModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/model_executor/models/mamba2.py
View file @
881e3cbe
...
@@ -10,6 +10,7 @@ from transformers import MambaConfig
...
@@ -10,6 +10,7 @@ from transformers import MambaConfig
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
get_forward_context
...
@@ -79,11 +80,12 @@ class Mamba2DecoderLayer(nn.Module):
...
@@ -79,11 +80,12 @@ class Mamba2DecoderLayer(nn.Module):
else
:
else
:
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mixer
(
hidden_states
,
mamba_cache_params
,
output
=
torch
.
empty_like
(
hidden_states
)
mamba2_metadata
)
self
.
mixer
(
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
)
return
hidden_states
,
residual
return
output
,
residual
@
support_torch_compile
class
Mamba2Model
(
nn
.
Module
):
class
Mamba2Model
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/model_executor/models/nemotron_h.py
View file @
881e3cbe
...
@@ -25,6 +25,7 @@ from torch import nn
...
@@ -25,6 +25,7 @@ from torch import nn
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
...
@@ -172,9 +173,9 @@ class NemotronHMambaDecoderLayer(nn.Module):
...
@@ -172,9 +173,9 @@ class NemotronHMambaDecoderLayer(nn.Module):
else
:
else
:
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mixer
(
hidden_states
,
mamba_cache_params
,
output
=
torch
.
empty_like
(
hidden_states
)
mamba2_metadata
)
self
.
mixer
(
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
)
return
hidden_states
,
residual
return
output
,
residual
class
NemotronHAttention
(
nn
.
Module
):
class
NemotronHAttention
(
nn
.
Module
):
...
@@ -292,6 +293,7 @@ ALL_DECODER_LAYER_TYPES = {
...
@@ -292,6 +293,7 @@ ALL_DECODER_LAYER_TYPES = {
}
}
@
support_torch_compile
class
NemotronHModel
(
nn
.
Module
):
class
NemotronHModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/model_executor/models/zamba2.py
View file @
881e3cbe
...
@@ -17,6 +17,7 @@ from transformers import Zamba2Config
...
@@ -17,6 +17,7 @@ from transformers import Zamba2Config
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
get_forward_context
...
@@ -548,14 +549,16 @@ class Zamba2MambaDecoderLayer(nn.Module):
...
@@ -548,14 +549,16 @@ class Zamba2MambaDecoderLayer(nn.Module):
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
# Process through Mamba mixer
# Process through Mamba mixer
hidden_states
=
self
.
mamba
(
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mamba
(
hidden_states
,
hidden_states
,
output
,
mamba_cache_params
=
mamba_cache_params
,
mamba_cache_params
=
mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
mamba2_metadata
=
mamba2_metadata
,
)
)
# residual connection after mamba
# residual connection after mamba
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
output
return
hidden_states
return
hidden_states
...
@@ -646,6 +649,7 @@ class Zamba2HybridLayer(nn.Module):
...
@@ -646,6 +649,7 @@ class Zamba2HybridLayer(nn.Module):
return
layer_outputs
return
layer_outputs
@
support_torch_compile
class
Zamba2Model
(
nn
.
Module
):
class
Zamba2Model
(
nn
.
Module
):
"""Core Zamba2 model combining transformer and Mamba architectures.
"""Core Zamba2 model combining transformer and Mamba architectures.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
881e3cbe
...
@@ -2753,9 +2753,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2753,9 +2753,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
vllm_config
.
speculative_config
is
not
None
:
if
self
.
vllm_config
.
speculative_config
is
not
None
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Mamba with speculative decoding is not supported yet."
)
"Mamba with speculative decoding is not supported yet."
)
if
not
self
.
vllm_config
.
model_config
.
enforce_eager
:
raise
NotImplementedError
(
"Mamba with cuda graph is not supported yet."
)
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Prefix caching is not supported for Mamba yet."
)
"Prefix caching is not supported for Mamba yet."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment