Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a903669e
Unverified
Commit
a903669e
authored
Sep 23, 2025
by
Thomas Parnell
Committed by
GitHub
Sep 23, 2025
Browse files
[V1] Remove V0 code paths for Hybrid models (#25400)
Signed-off-by:
Thomas Parnell
<
tpa@zurich.ibm.com
>
parent
2c58742d
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
153 additions
and
1147 deletions
+153
-1147
vllm/model_executor/models/minimax_text_01.py
vllm/model_executor/models/minimax_text_01.py
+0
-39
vllm/model_executor/models/nemotron_h.py
vllm/model_executor/models/nemotron_h.py
+3
-69
vllm/model_executor/models/phi4flash.py
vllm/model_executor/models/phi4flash.py
+0
-731
vllm/model_executor/models/plamo2.py
vllm/model_executor/models/plamo2.py
+59
-181
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+8
-28
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+0
-1
vllm/model_executor/models/zamba2.py
vllm/model_executor/models/zamba2.py
+0
-93
vllm/v1/attention/backends/gdn_attn.py
vllm/v1/attention/backends/gdn_attn.py
+7
-1
vllm/v1/attention/backends/mamba2_attn.py
vllm/v1/attention/backends/mamba2_attn.py
+12
-3
vllm/v1/attention/backends/short_conv_attn.py
vllm/v1/attention/backends/short_conv_attn.py
+13
-1
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+51
-0
No files found.
vllm/model_executor/models/minimax_text_01.py
View file @
a903669e
...
...
@@ -14,7 +14,6 @@ import torch.distributed
from
torch
import
nn
from
transformers
import
MiniMaxConfig
from
vllm
import
envs
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
...
...
@@ -44,7 +43,6 @@ from vllm.model_executor.models.utils import maybe_prefix
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
HasInnerState
,
IsHybrid
from
.minimax_cache
import
MinimaxCacheManager
,
MinimaxCacheParams
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
...
...
@@ -404,7 +402,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
Union
[
list
[
dict
],
Optional
[
torch
.
Tensor
]],
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
is_warmup
:
bool
=
False
,
...
...
@@ -418,7 +415,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
hidden_states
=
layernorm_output
,
output
=
self_attention_output
,
positions
=
positions
,
kv_caches
=
kv_caches
,
)
residual
=
residual
*
self
.
layernorm_attention_alpha
...
...
@@ -563,10 +559,6 @@ class MiniMaxText01Model(nn.Module):
self
.
_dtype
=
_dummy
.
dtype
del
_dummy
if
not
envs
.
VLLM_USE_V1
:
self
.
minimax_cache
=
MinimaxCacheManager
(
dtype
=
torch
.
float32
,
cache_shape
=
self
.
cache_shape
)
norm_kwargs
=
{}
if
hasattr
(
config
,
"rms_norm_eps"
):
norm_kwargs
[
"eps"
]
=
config
.
rms_norm_eps
...
...
@@ -614,25 +606,6 @@ class MiniMaxText01Model(nn.Module):
**
kwargs
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
forward_context
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
not
envs
.
VLLM_USE_V1
and
attn_metadata
is
None
:
return
None
if
not
envs
.
VLLM_USE_V1
:
if
"request_ids_to_seq_ids"
not
in
kwargs
:
kwargs
[
"request_ids_to_seq_ids"
]
=
{}
if
"finished_requests_ids"
not
in
kwargs
:
kwargs
[
"finished_requests_ids"
]
=
[]
(
minimax_cache_tensors
,
state_indices_tensor
,
)
=
self
.
minimax_cache
.
current_run_tensors
(
**
kwargs
)
if
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
>
0
:
self
.
_clear_prefill_cache
(
attn_metadata
,
minimax_cache_tensors
,
**
kwargs
)
minimax_cache_params
=
MinimaxCacheParams
(
minimax_cache_tensors
,
state_indices_tensor
)
else
:
minimax_cache_params
=
None
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
None
:
...
...
@@ -645,20 +618,10 @@ class MiniMaxText01Model(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
minimax_cache_index
=
0
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
):
_caches
=
None
if
not
envs
.
VLLM_USE_V1
and
isinstance
(
layer
.
self_attn
,
MiniMaxText01LinearAttention
):
current_state_layer
=
minimax_cache_index
_caches
=
minimax_cache_params
.
at_layer_idx
(
current_state_layer
)
minimax_cache_index
+=
1
hidden_states
,
residual
=
layer
(
hidden_states
=
hidden_states
,
positions
=
positions
,
kv_caches
=
_caches
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
)
...
...
@@ -1003,13 +966,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def
get_mamba_state_shape_from_config
(
cls
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
...],
...]:
"""Calculate shape for MiniMaxText01LinearAttention cache.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
...
...
vllm/model_executor/models/nemotron_h.py
View file @
a903669e
...
...
@@ -23,21 +23,17 @@ from typing import Optional
import
torch
from
torch
import
nn
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
ReLUSquaredActivation
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
...
...
@@ -49,14 +45,11 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
SupportsQuant
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
NemotronHConfig
from
vllm.utils
import
LayerBlockType
class
NemotronHMLP
(
nn
.
Module
):
...
...
@@ -181,8 +174,6 @@ class NemotronHMambaDecoderLayer(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
):
if
residual
is
None
:
...
...
@@ -192,7 +183,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
output
=
torch
.
empty_like
(
hidden_states
)
self
.
mixer
(
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
)
self
.
mixer
(
hidden_states
,
output
)
return
output
,
residual
...
...
@@ -370,22 +361,10 @@ class NemotronHModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
attn_metadata
=
get_forward_context
().
attn_metadata
if
not
envs
.
VLLM_USE_V1
:
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
chunk_size
,
attn_metadata
=
attn_metadata
,
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
...
...
@@ -398,22 +377,11 @@ class NemotronHModel(nn.Module):
residual
=
intermediate_tensors
[
"residual"
]
residual
=
None
num_non_mamba_layers
=
0
for
i
,
layer
in
enumerate
(
self
.
layers
):
layer_mamba_cache_params
=
None
if
isinstance
(
layer
,
NemotronHMambaDecoderLayer
)
and
mamba_cache_params
:
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
num_non_mamba_layers
)
else
:
num_non_mamba_layers
+=
1
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
...
...
@@ -508,13 +476,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
get_mamba_state_shape_from_config
(
cls
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
...
...
@@ -533,7 +499,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim
=
hf_config
.
mamba_head_dim
,
state_size
=
hf_config
.
ssm_state_size
,
conv_kernel
=
hf_config
.
conv_kernel
,
use_v1
=
use_v1
,
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -566,8 +531,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
# Used to track and store the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
...
...
@@ -584,40 +547,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
mamba_cache_params
=
None
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
num_mamba_layers
=
\
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
mamba_state_shape
=
\
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
,
use_v1
=
False
)
mamba_state_dtype
=
\
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_mamba_layers
,
*
mamba_state_shape
,
*
mamba_state_dtype
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
return
self
.
mamba_cache
.
copy_inputs_before_cuda_graphs
(
input_buffers
,
**
kwargs
)
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
return
self
.
mamba_cache
.
get_seqlen_agnostic_capture_inputs
(
batch_size
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/phi4flash.py
deleted
100644 → 0
View file @
2c58742d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
collections.abc
import
Iterable
from
typing
import
Optional
,
Union
import
torch
import
torch.nn
as
nn
from
transformers.activations
import
ACT2FN
import
vllm.envs
as
envs
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsV0Only
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
make_layers
,
maybe_prefix
logger
=
init_logger
(
__name__
)
class
SwiGLUActivation
(
nn
.
Module
):
def
forward
(
self
,
x1
:
torch
.
Tensor
,
x2
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
x1
*
nn
.
functional
.
silu
(
x2
)
class
SambaYMLP
(
nn
.
Module
):
"""Gated Linear Unit.
Reference:
Language Modeling with Gated Convolutional Networks.
https://arxiv.org/pdf/1612.08083v3.pdf.
"""
def
__init__
(
self
,
config
):
super
().
__init__
()
self
.
config
=
config
self
.
fc1
=
nn
.
Linear
(
config
.
hidden_size
,
2
*
config
.
intermediate_size
,
bias
=
False
)
self
.
fc2
=
nn
.
Linear
(
config
.
intermediate_size
,
config
.
hidden_size
,
bias
=
False
)
self
.
activation_fn
=
ACT2FN
[
config
.
hidden_act
]
def
forward
(
self
,
hidden_states
):
y
=
self
.
fc1
(
hidden_states
)
gate
,
y
=
y
.
chunk
(
2
,
dim
=-
1
)
y
=
y
*
self
.
activation_fn
(
gate
)
return
self
.
fc2
(
y
)
def
get_virtual_engine
():
forward_context
:
ForwardContext
=
get_forward_context
()
return
forward_context
.
virtual_engine
class
SambaYAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
,
layer_idx
:
Optional
[
int
]
=
None
,
yoco_cross
:
bool
=
False
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
if
layer_idx
is
None
:
logger
.
warning_once
(
f
"Instantiating
{
self
.
__class__
.
__name__
}
without passing "
"a `layer_idx` is not recommended and will lead to errors "
"during the forward call if caching is used. Please make "
"sure to provide a `layer_idx` when creating this class."
)
self
.
hidden_size
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
head_dim
=
self
.
hidden_size
//
self
.
num_heads
self
.
num_key_value_heads
=
config
.
num_key_value_heads
self
.
yoco_cross
=
yoco_cross
if
(
self
.
head_dim
*
self
.
num_heads
)
!=
self
.
hidden_size
:
raise
ValueError
(
"hidden_size must be divisible by num_heads "
f
"(got `hidden_size`:
{
self
.
hidden_size
}
and "
f
"`num_heads`:
{
self
.
num_heads
}
)."
)
op_size
=
self
.
num_heads
*
self
.
head_dim
+
2
*
(
self
.
num_key_value_heads
*
self
.
head_dim
)
self
.
out_proj
=
nn
.
Linear
(
self
.
num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
bias
=
True
)
if
yoco_cross
:
self
.
Wqkv
=
nn
.
Linear
(
self
.
hidden_size
,
self
.
num_heads
*
self
.
head_dim
,
bias
=
True
)
else
:
self
.
Wqkv
=
nn
.
Linear
(
self
.
hidden_size
,
op_size
,
bias
=
True
)
# disable sliding window for the second half of the model
is_sliding
=
config
.
layer_types
[
layer_idx
]
==
"sliding_attention"
sliding_window
=
config
.
sliding_window
if
is_sliding
else
None
assert
self
.
num_heads
%
2
==
0
,
'num_heads should be even'
assert
self
.
num_key_value_heads
%
2
==
0
,
'num_heads should be even'
self
.
lambda_init
=
self
.
lambda_init_fn
(
layer_idx
)
self
.
lambda_q1
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
head_dim
,
dtype
=
torch
.
float32
).
normal_
(
mean
=
0
,
std
=
0.1
))
self
.
lambda_k1
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
head_dim
,
dtype
=
torch
.
float32
).
normal_
(
mean
=
0
,
std
=
0.1
))
self
.
lambda_q2
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
head_dim
,
dtype
=
torch
.
float32
).
normal_
(
mean
=
0
,
std
=
0.1
))
self
.
lambda_k2
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
head_dim
,
dtype
=
torch
.
float32
).
normal_
(
mean
=
0
,
std
=
0.1
))
self
.
subln
=
nn
.
RMSNorm
(
2
*
self
.
head_dim
,
eps
=
1e-5
,
elementwise_affine
=
True
)
params
=
{
'differential_flash_attention_config'
:
{
'lambda_init'
:
self
.
lambda_init
,
'lambda_q1'
:
self
.
lambda_q1
,
'lambda_k1'
:
self
.
lambda_k1
,
'lambda_q2'
:
self
.
lambda_q2
,
'lambda_k2'
:
self
.
lambda_k2
,
"subln"
:
self
.
subln
,
}
}
if
yoco_cross
:
kv_shared_layer_index
=
config
.
num_hidden_layers
//
2
+
1
kv_sharing_target_layer_name
=
\
f
"model.layers.
{
kv_shared_layer_index
}
.self_attn.attn"
else
:
kv_sharing_target_layer_name
=
None
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
**-
0.5
,
num_kv_heads
=
self
.
num_key_value_heads
,
cache_config
=
cache_config
,
per_layer_sliding_window
=
sliding_window
,
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
,
**
params
)
assert
self
.
attn
.
backend
==
_Backend
.
DIFFERENTIAL_FLASH_ATTN
,
\
"DIFFERENTIAL_FLASH_ATTN required"
def
lambda_init_fn
(
self
,
depth
):
return
0.8
-
0.6
*
math
.
exp
(
-
0.3
*
depth
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
):
if
not
self
.
yoco_cross
:
# need to generate kv-cache
qkv
=
self
.
Wqkv
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
hidden_size
,
self
.
num_key_value_heads
*
self
.
head_dim
,
self
.
num_key_value_heads
*
self
.
head_dim
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
else
:
# reuse the kv cache, full attention
q
=
self
.
Wqkv
(
hidden_states
)
attn_output
=
self
.
attn
(
q
,
None
,
None
)
attn_output
=
attn_output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_dim
)
return
self
.
out_proj
(
attn_output
)
class
Phi4Mamba
(
nn
.
Module
):
def
__init__
(
self
,
d_model
,
d_state
=
16
,
d_conv
=
4
,
expand
=
2
,
dt_rank
=
"auto"
,
dt_min
=
0.001
,
dt_max
=
0.1
,
dt_init
=
"random"
,
# difference
dt_scale
=
1.0
,
# difference
dt_init_floor
=
1e-4
,
conv_bias
=
True
,
bias
=
False
,
use_fast_path
=
True
,
# Fused kernel options
layer_idx
=
None
,
device
=
None
,
dtype
=
None
,
yoco_cross
=
False
,
yoco_kv
=
False
,
):
factory_kwargs
=
{
"params_dtype"
:
dtype
}
# difference
super
().
__init__
()
self
.
yoco_cross
=
yoco_cross
self
.
yoco_kv
=
yoco_kv
self
.
d_model
=
d_model
self
.
d_state
=
d_state
self
.
d_conv
=
d_conv
self
.
expand
=
expand
self
.
d_inner
=
int
(
self
.
expand
*
self
.
d_model
)
self
.
dt_rank
=
math
.
ceil
(
self
.
d_model
/
16
)
if
dt_rank
==
"auto"
else
dt_rank
self
.
use_fast_path
=
use_fast_path
self
.
layer_idx
=
layer_idx
self
.
swiGluActivation
=
SwiGLUActivation
()
if
self
.
yoco_cross
:
self
.
in_proj
=
MergedColumnParallelLinear
(
self
.
d_model
,
[
self
.
d_inner
],
bias
=
bias
,
**
factory_kwargs
)
self
.
out_proj
=
RowParallelLinear
(
self
.
d_inner
,
self
.
d_model
,
bias
=
bias
,
**
factory_kwargs
)
return
self
.
conv1d
=
ColumnParallelLinear
(
input_size
=
d_conv
,
output_size
=
self
.
d_inner
,
bias
=
conv_bias
,
params_dtype
=
dtype
,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow overriding it.
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
in_proj
=
MergedColumnParallelLinear
(
self
.
d_model
,
[
self
.
d_inner
]
*
2
,
bias
=
bias
,
params_dtype
=
dtype
,
)
# selective projection used to make dt, B and C input dependent
self
.
x_proj
=
RowParallelLinear
(
self
.
d_inner
,
self
.
dt_rank
+
self
.
d_state
*
2
,
bias
=
False
,
params_dtype
=
dtype
,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self
.
dt_proj
=
ColumnParallelLinear
(
self
.
dt_rank
,
self
.
d_inner
,
bias
=
True
,
skip_bias_add
=
True
,
params_dtype
=
dtype
,
)
# # D "skip" parameter
# self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32
self
.
A
=
nn
.
Parameter
(
torch
.
empty
(
self
.
d_inner
,
self
.
d_state
,
dtype
=
torch
.
float32
,
))
self
.
D
=
nn
.
Parameter
(
torch
.
ones
(
self
.
d_inner
,
dtype
=
torch
.
float32
))
self
.
out_proj
=
RowParallelLinear
(
self
.
d_inner
,
self
.
d_model
,
bias
=
bias
,
input_is_parallel
=
True
,
params_dtype
=
dtype
,
)
self
.
activation
=
"silu"
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
yoco_key_values
=
None
)
->
torch
.
Tensor
:
if
self
.
yoco_cross
:
out
=
self
.
in_proj
(
hidden_states
)[
0
]
out
=
self
.
swiGluActivation
(
yoco_key_values
,
out
)
out
=
self
.
out_proj
(
out
)
return
out
[
0
],
yoco_key_values
# 1. Gated MLP's linear projection
# projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
projected_states
=
self
.
in_proj
(
hidden_states
.
to
(
self
.
in_proj
.
weight
.
dtype
))[
0
].
transpose
(
-
2
,
-
1
)
hidden_states
,
gate
=
projected_states
.
chunk
(
2
,
dim
=-
2
)
# 2. Convolution sequence transformation
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
))
if
attn_metadata
.
query_start_loc
is
not
None
\
and
attn_metadata
.
context_lens_tensor
is
not
None
:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states
=
causal_conv1d_fn
(
hidden_states
,
conv_weights
,
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
conv_states
=
mamba_cache_params
.
conv_state
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
hidden_states
=
causal_conv1d_update
(
hidden_states
.
transpose
(
0
,
1
),
mamba_cache_params
.
conv_state
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
activation
,
conv_state_indices
=
mamba_cache_params
.
state_indices_tensor
)
hidden_states
=
hidden_states
.
transpose
(
0
,
1
)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters
=
self
.
x_proj
(
hidden_states
.
transpose
(
-
2
,
-
1
))[
0
]
time_step
,
B
,
C
=
torch
.
split
(
ssm_parameters
,
[
self
.
dt_rank
,
self
.
d_state
,
self
.
d_state
],
dim
=-
1
,
)
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
discrete_time_step
=
self
.
dt_proj
(
time_step
)[
0
].
transpose
(
-
2
,
-
1
)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias
=
(
self
.
dt_proj
.
bias
.
float
()
if
hasattr
(
self
.
dt_proj
,
"bias"
)
else
None
)
if
attn_metadata
.
query_start_loc
is
not
None
\
and
attn_metadata
.
context_lens_tensor
is
not
None
:
scan_outputs
=
selective_scan_fn
(
hidden_states
,
mamba_cache_params
.
ssm_state
,
discrete_time_step
,
self
.
A
,
B
.
transpose
(
-
2
,
-
1
),
C
.
transpose
(
-
2
,
-
1
),
self
.
D
.
float
(),
# z,
None
if
self
.
yoco_kv
else
gate
,
time_proj_bias
,
delta_softplus
=
True
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
scan_outputs
=
torch
.
empty_like
(
hidden_states
.
transpose
(
0
,
1
))
selective_state_update
(
mamba_cache_params
.
ssm_state
,
hidden_states
.
transpose
(
0
,
1
),
discrete_time_step
.
transpose
(
0
,
1
),
self
.
A
,
B
,
C
,
self
.
D
,
# z
# gate.transpose(0, 1),
None
if
self
.
yoco_kv
else
gate
.
transpose
(
0
,
1
),
time_proj_bias
,
dt_softplus
=
True
,
state_batch_indices
=
mamba_cache_params
.
state_indices_tensor
,
out
=
scan_outputs
)
scan_outputs
=
scan_outputs
.
transpose
(
0
,
1
)
# 4. Final linear projection
if
self
.
yoco_kv
:
# gate = gate.transpose(-1,-2).contiguous()
yoco_key_values
=
scan_outputs
.
transpose
(
-
2
,
-
1
)
scan_outputs
=
self
.
swiGluActivation
(
scan_outputs
,
gate
)
contextualized_states
=
self
.
out_proj
(
scan_outputs
.
transpose
(
-
2
,
-
1
))[
0
]
return
contextualized_states
,
yoco_key_values
class
SambaYDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
,
layer_idx
,
cache_config
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
layer_idx
=
layer_idx
self
.
mlp
=
SambaYMLP
(
config
)
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
yoco_mb
=
False
self
.
yoco_cross
=
False
if
layer_idx
>=
config
.
num_hidden_layers
//
2
:
self
.
yoco_mb
=
True
self
.
yoco_cross
=
(
layer_idx
>=
(
config
.
num_hidden_layers
//
2
+
2
))
self
.
use_mamba
=
config
.
mb_per_layer
>
0
and
\
layer_idx
%
config
.
mb_per_layer
==
0
if
self
.
use_mamba
:
factory_kwargs
=
{
"dtype"
:
None
}
self
.
attn
=
Phi4Mamba
(
config
.
hidden_size
,
layer_idx
=
layer_idx
,
yoco_cross
=
self
.
yoco_cross
,
yoco_kv
=
self
.
yoco_mb
,
**
factory_kwargs
)
else
:
self
.
attn
=
SambaYAttention
(
config
,
layer_idx
=
layer_idx
,
yoco_cross
=
self
.
yoco_cross
,
cache_config
=
cache_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
ssm_output
:
Optional
[
torch
.
LongTensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
self
.
use_mamba
:
assert
mamba_cache_params
is
not
None
else
:
assert
mamba_cache_params
is
None
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
.
to
(
dtype
=
self
.
input_layernorm
.
weight
.
dtype
))
if
self
.
use_mamba
:
attn_outputs
,
ssm_output
=
self
.
attn
(
hidden_states
,
attn_metadata
,
mamba_cache_params
,
yoco_key_values
=
ssm_output
)
residual
=
residual
.
to
(
torch
.
float32
)
else
:
attn_outputs
=
self
.
attn
(
hidden_states
,
)
hidden_states
=
residual
+
attn_outputs
residual
=
hidden_states
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
.
to
(
dtype
=
self
.
post_attention_layernorm
.
weight
.
dtype
))
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
,
ssm_output
class
SambaYModel
(
nn
.
Module
):
def
__init__
(
self
,
config
,
cache_config
=
None
,
quant_config
=
None
,
lora_config
=
None
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
# Pipeline parallel is not supported since the second half of
# the layers share the kv cache.
if
get_pp_group
().
world_size
!=
1
:
raise
ValueError
(
"Pipeline Parallel not supported"
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
SambaYDecoderLayer
(
config
,
int
(
prefix
.
split
(
'.'
)[
-
1
]),
cache_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
final_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
mamba_state_idx
=
0
ssm_output
=
None
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
if
i
==
self
.
config
.
num_hidden_layers
//
2
+
2
:
# profile run
kv_cache_idx
=
self
.
config
.
num_hidden_layers
//
2
+
1
cache_layer
=
self
.
layers
[
kv_cache_idx
]
kv_cache
=
cache_layer
.
attn
.
attn
.
kv_cache
if
kv_cache
[
0
].
numel
()
==
0
:
break
# Starting from this layer, we do not need to calculate
# the kv cache since we reuse the kv cache from last layer.
# If in prefill phase, we can truncate
# the hidden state to save computation cost.
if
attn_metadata
.
prefill_metadata
and
not
envs
.
VLLM_USE_V1
:
selected_token_indices
=
torch
.
cumsum
(
attn_metadata
.
seq_lens_tensor
,
dim
=
0
)
-
1
hidden_states
=
hidden_states
.
index_select
(
0
,
selected_token_indices
)
ssm_output
=
ssm_output
.
index_select
(
0
,
selected_token_indices
)
if
layer
.
use_mamba
:
if
i
<
self
.
config
.
num_hidden_layers
//
2
or
\
not
layer
.
yoco_cross
:
mamba_cache
=
mamba_cache_params
.
at_layer_idx
(
mamba_state_idx
)
mamba_state_idx
+=
1
else
:
mamba_cache
=
mamba_cache_params
.
at_layer_idx
(
mamba_state_idx
-
1
)
hidden_states
,
ssm_output
=
layer
(
hidden_states
,
positions
,
attn_metadata
,
mamba_cache
,
ssm_output
=
ssm_output
)
else
:
hidden_states
,
ssm_output
=
layer
(
hidden_states
,
positions
,
attn_metadata
,
None
,
# mamba_cache_params
ssm_output
=
ssm_output
)
hidden_states
=
self
.
final_layernorm
(
hidden_states
.
to
(
dtype
=
self
.
final_layernorm
.
weight
.
dtype
))
return
hidden_states
class
Phi4FlashForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsHybrid
,
SupportsV0Only
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
lora_config
=
vllm_config
.
lora_config
quant_config
=
vllm_config
.
quant_config
scheduler_config
=
vllm_config
.
scheduler_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
vllm_config
=
vllm_config
# Prefix caching and chunked prefill are not supported for this model.
assert
not
cache_config
.
enable_prefix_caching
,
\
"Phi4flash currently does not support prefix caching"
assert
not
scheduler_config
.
chunked_prefill_enabled
,
\
"Phi4Flash currently does not support prefix caching"
super
().
__init__
()
self
.
config
=
config
self
.
model_config
=
vllm_config
.
model_config
self
.
scheduler_config
=
scheduler_config
self
.
model
=
SambaYModel
(
config
,
cache_config
=
cache_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
self
.
embedding_bias
=
None
# Used to track and store the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logits_as_input
=
False
)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the wrapped model, creating the Mamba cache manager on first use.

    While ``self.mamba_cache`` is still ``None`` a ``MambaCacheManager`` is
    built from ``self.vllm_config``, a layer count derived from the HF
    config, the two shapes returned by ``_get_mamba_cache_shape`` and the
    dtype of ``self.lm_head.weight`` (passed twice).

    Args:
        input_ids: Forwarded unchanged to ``self.model``.
        positions: Forwarded unchanged to ``self.model``.
        intermediate_tensors: Forwarded unchanged to ``self.model``.
        inputs_embeds: Forwarded unchanged to ``self.model``.
        **kwargs: Forwarded to ``self.mamba_cache.current_run_tensors``.

    Returns:
        Whatever ``self.model`` returns for this step.
    """
    if self.mamba_cache is None:
        hf_config = self.config
        layer_count = (hf_config.num_hidden_layers // 2
                       // hf_config.mb_per_layer + 1)
        conv_shape, ssm_shape = self._get_mamba_cache_shape()
        # NOTE(review): the same dtype is supplied for both trailing
        # arguments -- presumably conv-state and ssm-state dtypes; confirm
        # against MambaCacheManager's signature.
        state_dtype = self.lm_head.weight.dtype
        self.mamba_cache = MambaCacheManager(
            self.vllm_config,
            layer_count,
            conv_shape,
            ssm_shape,
            state_dtype,
            state_dtype,
        )

    cache_params = self.mamba_cache.current_run_tensors(**kwargs)
    step_attn_metadata = get_forward_context().attn_metadata

    # input_ids and hidden_states aren't a one-to-one mapping in the prefill
    # stage due to the YOCO optimization.
    return self.model(
        input_ids,
        positions,
        step_attn_metadata,
        cache_params,
        intermediate_tensors,
        inputs_embeds,
    )
def _get_mamba_cache_shape(
    self
) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
    """Compute the per-rank shapes of the two Mamba state caches.

    Both shapes share the leading dimension
    ``mamba_expand * hidden_size // world_size`` where ``world_size`` is the
    tensor-model-parallel world size.

    Returns:
        ``(conv_state_shape, temporal_state_shape)`` where the conv shape's
        second entry is ``mamba_d_conv - 1`` and the temporal shape's second
        entry is ``mamba_d_state``.
    """
    tp_world_size = get_tensor_model_parallel_world_size()
    hf_config = self.config

    # Evaluated left to right: (mamba_expand * hidden_size) // world_size.
    # The original source annotates mamba_expand with the value 2.
    sharded_inner_dim = (hf_config.mamba_expand * hf_config.hidden_size
                         // tp_world_size)

    # The original source annotates mamba_d_conv with 4 and mamba_d_state
    # with 16.
    conv_state_shape = (sharded_inner_dim, hf_config.mamba_d_conv - 1)
    temporal_state_shape = (sharded_inner_dim, hf_config.mamba_d_state)
    return conv_state_shape, temporal_state_shape
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
    """Delegate to the Mamba cache manager's method of the same name.

    Args:
        input_buffers: Passed through positionally.
        **kwargs: Passed through as keyword arguments.

    Returns:
        The cache manager's return value, unchanged.
    """
    cache_manager = self.mamba_cache
    return cache_manager.copy_inputs_before_cuda_graphs(
        input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
    """Delegate to the Mamba cache manager's method of the same name.

    Args:
        batch_size: Passed through unchanged.

    Returns:
        The cache manager's return value, unchanged.
    """
    cache_manager = self.mamba_cache
    return cache_manager.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
    """Turn hidden states into logits through ``self.logits_processor``.

    Args:
        hidden_states: Passed as the second argument to the processor.

    Returns:
        The result of calling ``self.logits_processor`` with
        ``self.lm_head``, ``hidden_states`` and ``self.embedding_bias``.
    """
    head, bias = self.lm_head, self.embedding_bias
    return self.logits_processor(head, hidden_states, bias)
def load_weights(
    self,
    weights: Iterable[tuple[str, torch.Tensor]],
):
    """Load checkpoint tensors into this module after renaming them.

    Renaming rules applied to every checkpoint entry:
      * names containing ``"A_log"`` have it replaced by ``"A"`` and the
        tensor replaced by ``-exp(tensor.float())``;
      * the substring ``"inner_cross_attn."`` is removed from names.
    ``"lm_head.weight"`` is then set to the checkpoint's
    ``"model.embed_tokens.weight"`` tensor.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        The set of all parameter names reported by ``named_parameters()``.

    Raises:
        KeyError: If ``"model.embed_tokens.weight"`` is absent.
        AssertionError: If ``load_state_dict`` reports unexpected or
            missing keys.
    """
    checkpoint = dict(weights)

    remapped = {}
    for ckpt_name, tensor in checkpoint.items():
        target_name = ckpt_name
        if "A_log" in target_name:
            target_name = target_name.replace("A_log", "A")
            tensor = -torch.exp(tensor.float())
        if "inner_cross_attn." in target_name:
            target_name = target_name.replace("inner_cross_attn.", "")
        remapped[target_name] = tensor

    # The LM head reuses the input embedding tensor, read under its original
    # (un-renamed) checkpoint key.
    remapped["lm_head.weight"] = checkpoint["model.embed_tokens.weight"]

    loaded_params: set[str] = set()
    for param_name, param in self.named_parameters():
        candidate = remapped.get(param_name)
        shapes_differ = (candidate is not None
                         and candidate.shape != param.shape)
        if shapes_differ:
            logger.warning("Shape mismatch: %s %s %s", param_name,
                           candidate.shape, param.shape)
        # Every parameter is recorded; completeness is enforced by the
        # missing-keys assertion below.
        loaded_params.add(param_name)

    load_result = self.load_state_dict(remapped, strict=False)
    missing_keys, unexpected_keys = load_result
    assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
    assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
    return loaded_params
vllm/model_executor/models/plamo2.py
View file @
a903669e
...
...
@@ -12,7 +12,6 @@ import torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
...
...
@@ -29,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
,
update_metadata
)
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
...
...
@@ -47,15 +44,13 @@ from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsPP
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.models.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
,
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.mamba2_attn
import
Mamba2AttentionMetadata
...
...
@@ -194,17 +189,13 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self
.
chunk_size
=
self
.
config
.
mamba_chunk_size
if
envs
.
VLLM_USE_V1
:
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self
.
kv_cache
=
[(
torch
.
tensor
([]),
torch
.
tensor
([]))]
assert
self
.
chunk_size
!=
-
1
,
"chunk_size must be set for v1"
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
# The tuple is (conv_state, ssm_state)
self
.
kv_cache
=
(
torch
.
tensor
([]),
torch
.
tensor
([]))
assert
self
.
chunk_size
!=
-
1
,
"chunk_size must be set for v1"
self
.
prefix
=
prefix
...
...
@@ -227,8 +218,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
):
pass
...
...
@@ -237,59 +226,43 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
):
if
not
envs
.
VLLM_USE_V1
:
CustomOp
.
forward
(
self
,
hidden_states
,
output
,
mamba_cache_params
,
mamba2_metadata
)
else
:
torch
.
ops
.
vllm
.
plamo2_mamba_mixer
(
hidden_states
,
output
,
self
.
prefix
,
)
torch
.
ops
.
vllm
.
plamo2_mamba_mixer
(
hidden_states
,
output
,
self
.
prefix
,
)
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
):
forward_context
=
get_forward_context
()
#
mamba2
_metadata contains metadata necessary for the mamba2 triton
#
attn
_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
if
envs
.
VLLM_USE_V1
:
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
mamba2_metadata
=
attn_metadata
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
else
:
conv_state
=
mamba_cache_params
.
conv_state
ssm_state
=
mamba_cache_params
.
ssm_state
state_indices_tensor
=
mamba_cache_params
.
state_indices_tensor
# Common members between V1 metadata and V0 metadata
if
mamba2_metadata
is
not
None
:
has_initial_states_p
=
mamba2_metadata
.
has_initial_states_p
prep_initial_states
=
mamba2_metadata
.
prep_initial_states
chunk_size
=
mamba2_metadata
.
chunk_size
seq_idx_p
=
mamba2_metadata
.
seq_idx_p
chunk_indices_p
=
mamba2_metadata
.
chunk_indices_p
chunk_offsets_p
=
mamba2_metadata
.
chunk_offsets_p
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
prep_initial_states
=
attn_metadata
.
prep_initial_states
chunk_size
=
attn_metadata
.
chunk_size
seq_idx_p
=
attn_metadata
.
seq_idx_p
chunk_indices_p
=
attn_metadata
.
chunk_indices_p
chunk_offsets_p
=
attn_metadata
.
chunk_offsets_p
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)
...
...
@@ -299,8 +272,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
))
if
envs
.
VLLM_USE_V1
and
attn_metadata
is
None
:
#
V1
profile run
if
attn_metadata
is
None
:
# profile run
hidden_states
=
(
hidden_states
.
transpose
(
0
,
1
).
clone
().
transpose
(
0
,
1
)).
contiguous
()
output
[:]
=
self
.
out_proj
(
hidden_states
)
...
...
@@ -316,42 +289,23 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
if
envs
.
VLLM_USE_V1
:
hidden_states_d
,
hidden_states_p
=
torch
.
split
(
hidden_states
[:
num_actual_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
gate_d
,
gate_p
=
torch
.
split
(
gate
[:
num_actual_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
)
# Split along batch dimension
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor
,
[
num_decodes
,
num_prefills
],
dim
=
0
,
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decodes
if
has_prefill
else
None
)
else
:
hidden_states_p
,
hidden_states_d
=
torch
.
split
(
hidden_states
,
[
num_prefill_tokens
,
num_decodes
],
dim
=
0
,
)
gate_p
,
gate_d
=
torch
.
split
(
gate
,
[
num_prefill_tokens
,
num_decodes
],
dim
=
0
)
# Split along batch dimension
state_indices_tensor_p
,
state_indices_tensor_d
=
torch
.
split
(
state_indices_tensor
,
[
num_prefills
,
num_decodes
],
dim
=
0
,
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[:
num_prefills
+
1
]
if
has_prefill
else
None
)
hidden_states_d
,
hidden_states_p
=
torch
.
split
(
hidden_states
[:
num_actual_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
gate_d
,
gate_p
=
torch
.
split
(
gate
[:
num_actual_tokens
],
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
)
# Split along batch dimension
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor
,
[
num_decodes
,
num_prefills
],
dim
=
0
,
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decodes
if
has_prefill
else
None
)
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
...
...
@@ -363,18 +317,11 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
,
)
if
envs
.
VLLM_USE_V1
:
preallocated_ssm_out_d
,
preallocated_ssm_out_p
=
torch
.
split
(
preallocated_ssm_out
,
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
else
:
preallocated_ssm_out_p
,
preallocated_ssm_out_d
=
torch
.
split
(
preallocated_ssm_out
,
[
num_prefill_tokens
,
num_decodes
],
dim
=
0
,
)
preallocated_ssm_out_d
,
preallocated_ssm_out_p
=
torch
.
split
(
preallocated_ssm_out
,
[
num_decodes
,
num_prefill_tokens
],
dim
=
0
,
)
# Process prefill requests
if
has_prefill
:
...
...
@@ -383,9 +330,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
# pointed to by "state_indices_tensor"
x
=
hidden_states_p
.
transpose
(
0
,
1
)
# this is the form that causal-conv see
if
mamba2_metadata
.
cu_seqlen
is
None
:
mamba2_metadata
=
update_metadata
(
x
,
query_start_loc_p
,
mamba2_metadata
)
hidden_states_p
=
causal_conv1d_fn
(
x
,
conv_weights
,
...
...
@@ -394,7 +338,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
conv_states
=
conv_state
,
has_initial_state
=
has_initial_states_p
,
cache_indices
=
state_indices_tensor_p
,
metadata
=
mamba2
_metadata
,
metadata
=
attn
_metadata
,
query_start_loc
=
query_start_loc_p
)
hidden_states_p
=
hidden_states_p
.
transpose
(
0
,
1
)
hidden_states_p
=
hidden_states_p
[:
num_prefill_tokens
]
...
...
@@ -470,7 +414,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
-
1
,
self
.
num_heads
//
self
.
tp_size
,
self
.
head_dim
)
# - the hidden is reshaped into (bs, num_heads, head_dim)
# -
mamba_cache_params.
ssm_state's slots will be selected
# - ssm_state's slots will be selected
# using state_indices_tensor_d
# NOTE: final output is an in-place update of out tensor
...
...
@@ -530,10 +474,7 @@ def plamo2_mamba_mixer(
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
output
=
output
,
mamba_cache_params
=
None
,
mamba2_metadata
=
None
)
self
.
forward_cuda
(
hidden_states
=
hidden_states
,
output
=
output
)
def
plamo2_mamba_mixer_fake
(
...
...
@@ -731,8 +672,6 @@ class Plamo2DecoderLayer(nn.Module):
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
):
if
residual
is
None
:
...
...
@@ -747,8 +686,6 @@ class Plamo2DecoderLayer(nn.Module):
output
=
torch
.
empty_like
(
hidden_states
)
mixer_kwargs
=
{
"output"
:
output
,
"mamba_cache_params"
:
mamba_cache_params
,
"mamba2_metadata"
:
mamba2_metadata
,
}
else
:
mixer_kwargs
=
{
...
...
@@ -790,23 +727,12 @@ class Plamo2Decoder(torch.nn.Module):
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
)
->
torch
.
Tensor
:
mamba_cache_index
=
0
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
):
layer_mamba_cache_params
=
None
if
layer
.
is_mamba
and
mamba_cache_params
is
not
None
:
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
mamba_cache_index
)
mamba_cache_index
+=
1
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
)
return
hidden_states
,
residual
...
...
@@ -844,7 +770,6 @@ class Plamo2Model(torch.nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -859,23 +784,10 @@ class Plamo2Model(torch.nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
if
not
envs
.
VLLM_USE_V1
:
attn_metadata
:
AttentionMetadata
=
get_forward_context
(
).
attn_metadata
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
mamba_chunk_size
,
attn_metadata
=
attn_metadata
,
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
hidden_states
,
residual
=
self
.
layers
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
...
...
@@ -925,9 +837,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
lm_head
.
tie_weights
(
self
.
model
.
embed_tokens
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
...
...
@@ -942,39 +851,11 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
num_mamba_layers
=
(
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
))
mamba_state_shape
=
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
,
use_v1
=
False
)
mamba_state_dtype
=
\
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_mamba_layers
,
*
mamba_state_shape
,
*
mamba_state_dtype
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
else
:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params
=
None
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_param
s
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensor
s
,
inputs_embeds
)
return
hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
    """Forward the call to ``self.mamba_cache`` and return its result.

    Args:
        input_buffers: Passed through positionally.
        **kwargs: Passed through as keyword arguments.
    """
    cache_manager = self.mamba_cache
    return cache_manager.copy_inputs_before_cuda_graphs(
        input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
    """Forward the call to ``self.mamba_cache`` and return its result.

    Args:
        batch_size: Passed through unchanged.
    """
    cache_manager = self.mamba_cache
    return cache_manager.get_seqlen_agnostic_capture_inputs(batch_size)
@
classmethod
def
get_mamba_state_dtype_from_config
(
cls
,
...
...
@@ -991,12 +872,10 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
def
get_mamba_state_shape_from_config
(
cls
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
...
...
@@ -1015,7 +894,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
head_dim
=
hf_config
.
hidden_size_per_head
,
state_size
=
hf_config
.
mamba_d_state
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
use_v1
=
use_v1
,
)
def
compute_logits
(
...
...
vllm/model_executor/models/qwen3_next.py
View file @
a903669e
...
...
@@ -11,7 +11,6 @@ from einops import rearrange
from
torch
import
nn
from
transformers.activations
import
ACT2FN
from
vllm
import
envs
from
vllm.attention
import
Attention
,
AttentionBackend
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
SpeculativeConfig
,
...
...
@@ -35,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
update_metadata
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
mamba_v2_sharded_weight_loader
)
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
...
...
@@ -51,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.models.qwen2_moe
import
Qwen2MoeMLP
as
Qwen3NextMLP
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
...
...
@@ -198,14 +195,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
def
get_state_shape
(
self
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...]]:
return
MambaStateShapeCalculator
.
gated_delta_net_state_shape
(
self
.
tp_size
,
self
.
num_k_heads
,
self
.
num_v_heads
,
self
.
head_k_dim
,
self
.
head_v_dim
,
self
.
conv_kernel_size
,
self
.
num_spec
,
use_v1
=
True
)
self
.
tp_size
,
self
.
num_k_heads
,
self
.
num_v_heads
,
self
.
head_k_dim
,
self
.
head_v_dim
,
self
.
conv_kernel_size
,
self
.
num_spec
)
def
__init__
(
self
,
...
...
@@ -394,7 +385,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
cache_params
:
Optional
[
MambaCacheParams
]
=
None
,
):
return
torch
.
ops
.
vllm
.
gdn_attention
(
hidden_states
,
...
...
@@ -416,7 +406,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
conv_metadata
=
attn_metadata
assert
isinstance
(
attn_metadata
,
GDNAttentionMetadata
)
has_initial_state
=
attn_metadata
.
has_initial_state
spec_query_start_loc
=
attn_metadata
.
spec_query_start_loc
...
...
@@ -479,12 +468,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# 2.2: process the remaining part
if
attn_metadata
.
num_prefills
>
0
:
mixed_qkv_non_spec_T
=
mixed_qkv_non_spec
.
transpose
(
0
,
1
)
if
conv_metadata
.
cu_seqlen
is
None
:
conv_metadata
=
update_metadata
(
mixed_qkv_non_spec_T
,
non_spec_query_start_loc
,
conv_metadata
)
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "
mamba_cache_params.
state_indices_tensor"
# pointed to by "state_indices_tensor"
mixed_qkv_non_spec
=
causal_conv1d_fn
(
mixed_qkv_non_spec_T
,
conv_weights
,
...
...
@@ -494,7 +479,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
has_initial_state
=
has_initial_state
,
cache_indices
=
non_spec_state_indices_tensor
,
query_start_loc
=
non_spec_query_start_loc
,
metadata
=
conv
_metadata
,
metadata
=
attn
_metadata
,
).
transpose
(
0
,
1
)
elif
attn_metadata
.
num_decodes
>
0
:
mixed_qkv_non_spec
=
causal_conv1d_update
(
...
...
@@ -1075,7 +1060,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
scheduler_config
=
vllm_config
.
scheduler_config
assert
not
cache_config
.
enable_prefix_caching
,
\
"Qwen3Next currently does not support prefix caching"
assert
envs
.
VLLM_USE_V1
,
"Qwen3Next requires VLLM_USE_V1"
self
.
quant_config
=
vllm_config
.
quant_config
super
().
__init__
()
...
...
@@ -1195,14 +1179,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
num_spec
=
(
vllm_config
.
speculative_config
.
num_speculative_tokens
if
vllm_config
.
speculative_config
else
0
)
return
MambaStateShapeCalculator
.
gated_delta_net_state_shape
(
tp_size
,
hf_config
.
linear_num_key_heads
,
hf_config
.
linear_num_value_heads
,
hf_config
.
linear_key_head_dim
,
hf_config
.
linear_value_head_dim
,
hf_config
.
linear_conv_kernel_dim
,
num_spec
,
use_v1
=
True
)
tp_size
,
hf_config
.
linear_num_key_heads
,
hf_config
.
linear_num_value_heads
,
hf_config
.
linear_key_head_dim
,
hf_config
.
linear_value_head_dim
,
hf_config
.
linear_conv_kernel_dim
,
num_spec
)
def
compute_logits
(
self
,
...
...
vllm/model_executor/models/registry.py
View file @
a903669e
...
...
@@ -134,7 +134,6 @@ _TEXT_GENERATION_MODELS = {
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"Phi3ForCausalLM"
:
(
"phi3"
,
"Phi3ForCausalLM"
),
"PhiMoEForCausalLM"
:
(
"phimoe"
,
"PhiMoEForCausalLM"
),
"Phi4FlashForCausalLM"
:
(
"phi4flash"
,
"Phi4FlashForCausalLM"
),
"Plamo2ForCausalLM"
:
(
"plamo2"
,
"Plamo2ForCausalLM"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
...
...
vllm/model_executor/models/zamba2.py
View file @
a903669e
...
...
@@ -15,12 +15,10 @@ import torch
from
torch
import
nn
from
transformers
import
Zamba2Config
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -29,8 +27,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
...
...
@@ -39,8 +35,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
HasInnerState
,
IsHybrid
...
...
@@ -515,8 +509,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
transformer_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
original_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -525,8 +517,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
Args:
hidden_states: Input tensor [batch_size, seq_len, hidden_size]
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
transformer_hidden_states: Optional output from transformer path
Added to input if provided (used in hybrid architecture)
positions: Optional position IDs (unused in Mamba)
...
...
@@ -555,8 +545,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
self
.
mamba
(
hidden_states
,
output
,
mamba_cache_params
=
mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
)
# residual connection after mamba
...
...
@@ -607,8 +595,6 @@ class Zamba2HybridLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
original_hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
)
->
torch
.
Tensor
:
"""Forward pass through the hybrid layer.
...
...
@@ -623,8 +609,6 @@ class Zamba2HybridLayer(nn.Module):
original_hidden_states: Original input for transformer residual
connection
positions: Position IDs for positional embeddings
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
Returns:
Output tensor combining transformer and Mamba representations
...
...
@@ -644,8 +628,6 @@ class Zamba2HybridLayer(nn.Module):
layer_outputs
=
self
.
mamba_decoder
(
hidden_states
,
transformer_hidden_states
=
transformer_hidden_states
,
mamba_cache_params
=
mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
)
return
layer_outputs
...
...
@@ -752,7 +734,6 @@ class Zamba2Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
"""Forward pass through the model.
...
...
@@ -760,8 +741,6 @@ class Zamba2Model(nn.Module):
Args:
input_ids: Input token IDs
positions: Position IDs for embeddings
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
inputs_embeds: Optional pre-computed input embeddings
Returns:
...
...
@@ -773,33 +752,13 @@ class Zamba2Model(nn.Module):
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
)
hidden_states
=
inputs_embeds
attn_metadata
=
get_forward_context
().
attn_metadata
if
not
envs
.
VLLM_USE_V1
:
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
chunk_size
,
attn_metadata
=
attn_metadata
,
)
else
:
# v1 get mamba2_metadata from forward_context
mamba2_metadata
=
None
# Process through layers
original_hidden_states
=
torch
.
clone
(
hidden_states
)
for
layer_idx
,
layer
in
enumerate
(
self
.
layers
):
layer_mamba_cache_params
=
None
if
(
isinstance
(
layer
,
(
Zamba2HybridLayer
,
Zamba2MambaDecoderLayer
))
and
mamba_cache_params
):
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
layer_idx
)
layer_outputs
=
layer
(
hidden_states
,
original_hidden_states
=
original_hidden_states
,
positions
=
positions
,
mamba_cache_params
=
layer_mamba_cache_params
,
mamba2_metadata
=
mamba2_metadata
,
)
hidden_states
=
layer_outputs
...
...
@@ -870,13 +829,11 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def
get_mamba_state_shape_from_config
(
cls
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
,
int
]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
...
...
@@ -896,7 +853,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
head_dim
=
hf_config
.
mamba_headdim
,
state_size
=
hf_config
.
mamba_d_state
,
conv_kernel
=
hf_config
.
mamba_d_conv
,
use_v1
=
use_v1
,
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
...
...
@@ -945,9 +901,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
# Tie weights with input embeddings if using same dimensions
self
.
lm_head
=
self
.
lm_head
.
tie_weights
(
self
.
model
.
embed_tokens
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
# Initialize logits processing and sampling
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
...
...
@@ -977,61 +930,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
Returns:
Output hidden states
"""
# Initialize Mamba cache if needed
mamba_cache_params
=
None
if
not
envs
.
VLLM_USE_V1
:
if
self
.
mamba_cache
is
None
:
num_mamba_layers
=
self
.
config
.
num_hidden_layers
mamba_state_shape
=
\
self
.
get_mamba_state_shape_from_config
(
self
.
vllm_config
,
use_v1
=
False
)
mamba_state_dtype
=
\
self
.
get_mamba_state_dtype_from_config
(
self
.
vllm_config
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
vllm_config
,
num_mamba_layers
,
*
mamba_state_shape
,
*
mamba_state_dtype
)
# Get cache parameters for current run
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
# Forward pass through model
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
inputs_embeds
,
)
return
hidden_states
def copy_inputs_before_cuda_graphs(
        self, input_buffers: dict[str, torch.Tensor],
        **kwargs: Any) -> dict[str, torch.Tensor]:
    """Hand the input buffers to the Mamba cache manager before capture.

    Args:
        input_buffers: Mapping of input tensors, passed through positionally.
        **kwargs: Extra keyword arguments passed to the cache manager.

    Returns:
        The buffers returned by the cache manager.
    """
    cache_manager = self.mamba_cache
    updated_buffers = cache_manager.copy_inputs_before_cuda_graphs(
        input_buffers, **kwargs)
    return updated_buffers
def get_seqlen_agnostic_capture_inputs(
        self, batch_size: int) -> dict[str, torch.Tensor]:
    """Ask the Mamba cache manager for its graph-capture inputs.

    Args:
        batch_size: Batch size to capture, passed through unchanged.

    Returns:
        The capture inputs returned by the cache manager.
    """
    cache_manager = self.mamba_cache
    capture_inputs = cache_manager.get_seqlen_agnostic_capture_inputs(
        batch_size)
    return capture_inputs
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/v1/attention/backends/gdn_attn.py
View file @
a903669e
...
...
@@ -12,6 +12,7 @@ from vllm.config import VllmConfig
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
compute_causal_conv1d_metadata
,
split_decodes_and_prefills
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
MambaSpec
...
...
@@ -52,7 +53,6 @@ class GDNAttentionMetadata:
# The following attributes are for triton implementation of causal_conv1d
nums_dict
:
Optional
[
dict
]
=
None
cu_seqlen
:
Optional
[
int
]
=
None
batch_ptr
:
Optional
[
torch
.
Tensor
]
=
None
token_chunk_offset_ptr
:
Optional
[
torch
.
Tensor
]
=
None
...
...
@@ -134,6 +134,7 @@ class GDNAttentionMetadataBuilder(
context_lens
=
m
.
num_computed_tokens_cpu
context_lens_tensor
=
context_lens
.
to
(
query_start_loc
.
device
)
seq_lens_tensor
=
m
.
seq_lens
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
if
(
not
self
.
use_spec_decode
or
num_draft_tokens
is
None
or
num_draft_tokens
.
sum
().
item
()
==
0
):
...
...
@@ -210,6 +211,8 @@ class GDNAttentionMetadataBuilder(
has_initial_state
=
context_lens_tensor
>
0
if
spec_sequence_masks
is
not
None
:
has_initial_state
=
has_initial_state
[
~
spec_sequence_masks
]
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
\
compute_causal_conv1d_metadata
(
non_spec_query_start_loc
)
else
:
has_initial_state
=
None
num_actual_tokens
=
num_prefill_tokens
+
num_decode_tokens
+
\
...
...
@@ -297,6 +300,9 @@ class GDNAttentionMetadataBuilder(
spec_sequence_masks
=
spec_sequence_masks
,
spec_token_masks
=
spec_token_masks
,
num_accepted_tokens
=
num_accepted_tokens
,
nums_dict
=
nums_dict
,
batch_ptr
=
batch_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
)
return
attn_metadata
...
...
vllm/v1/attention/backends/mamba2_attn.py
View file @
a903669e
...
...
@@ -7,11 +7,12 @@ from typing import Optional
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.config
import
VllmConfig
from
vllm.v1.attention.backends.mamba_attn
import
(
BaseMambaAttentionMetadataBuilder
)
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
from
vllm.v1.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionMetadata
,
compute_causal_conv1d_metadata
,
split_decodes_and_prefills
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
...
@@ -131,7 +132,6 @@ class Mamba2AttentionMetadata:
# The following attributes are for triton implementation of causal_conv1d
nums_dict
:
Optional
[
dict
]
=
None
cu_seqlen
:
Optional
[
int
]
=
None
batch_ptr
:
Optional
[
torch
.
Tensor
]
=
None
token_chunk_offset_ptr
:
Optional
[
torch
.
Tensor
]
=
None
...
...
@@ -161,6 +161,9 @@ class Mamba2AttentionMetadataBuilder(
has_initial_states_p
=
None
prep_initial_states
=
False
# for causal_conv1d
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
...
...
@@ -198,6 +201,9 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc_p
,
self
.
chunk_size
,
num_prefill_tokens
))
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
\
compute_causal_conv1d_metadata
(
query_start_loc_p
)
elif
num_decodes
<=
self
.
decode_cudagraph_max_bs
:
# Pad state tensor for CUDA graph
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_decodes
)
...
...
@@ -220,5 +226,8 @@ class Mamba2AttentionMetadataBuilder(
chunk_indices_p
=
chunk_indices_p
,
chunk_offsets_p
=
chunk_offsets_p
,
state_indices_tensor
=
state_indices_tensor
,
nums_dict
=
nums_dict
,
batch_ptr
=
batch_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
)
return
attn_metadata
vllm/v1/attention/backends/short_conv_attn.py
View file @
a903669e
...
...
@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from
vllm.config
import
VllmConfig
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
compute_causal_conv1d_metadata
,
split_decodes_and_prefills
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
MambaSpec
...
...
@@ -33,7 +34,6 @@ class ShortConvAttentionMetadata:
# For causal_conv1d
nums_dict
:
Optional
[
dict
]
=
None
cu_seqlen
:
Optional
[
int
]
=
None
batch_ptr
:
Optional
[
torch
.
Tensor
]
=
None
token_chunk_offset_ptr
:
Optional
[
torch
.
Tensor
]
=
None
...
...
@@ -57,6 +57,9 @@ class ShortConvAttentionMetadataBuilder(
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
# for causal_conv1d
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
...
...
@@ -70,6 +73,12 @@ class ShortConvAttentionMetadataBuilder(
has_initial_states
=
has_initial_states_cpu
.
to
(
query_start_loc
.
device
)
query_start_loc_p
=
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decode_tokens
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
\
compute_causal_conv1d_metadata
(
query_start_loc_p
)
attn_metadata
=
ShortConvAttentionMetadata
(
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
...
...
@@ -78,5 +87,8 @@ class ShortConvAttentionMetadataBuilder(
query_start_loc
=
query_start_loc
,
has_initial_states
=
has_initial_states
,
state_indices_tensor
=
state_indices_tensor
,
nums_dict
=
nums_dict
,
batch_ptr
=
batch_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
)
return
attn_metadata
vllm/v1/attention/backends/utils.py
View file @
a903669e
...
...
@@ -34,6 +34,8 @@ logger = init_logger(__name__)
KVCacheLayoutType
=
Literal
[
"NHD"
,
"HND"
]
_KV_CACHE_LAYOUT_OVERRIDE
:
Union
[
KVCacheLayoutType
,
None
]
=
None
PAD_SLOT_ID
=
-
1
def
is_valid_kv_cache_layout
(
value
:
str
)
->
bool
:
return
value
in
get_args
(
KVCacheLayoutType
)
...
...
@@ -838,3 +840,52 @@ def create_fast_prefill_custom_backend(
builder_cls
=
FastPrefillAttentionBuilder
)
return
attn_backend
def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
    """Build the chunk bookkeeping needed for causal_conv1d.

    The consecutive differences of ``query_start_loc_p`` are taken as the
    per-sequence token counts. Each sequence is split into chunks of
    ``BLOCK_M`` tokens, and for every chunk we record which sequence it
    belongs to and its position inside that sequence.

    Args:
        query_start_loc_p: Tensor of start offsets; presumably the 1-D
            cumulative query start locations of the prefill requests, as
            passed by the metadata builders - TODO confirm.

    Returns:
        A tuple ``(nums_dict, batch_ptr, token_chunk_offset_ptr)``:

        * ``nums_dict`` maps each ``BLOCK_M`` value to a dict with the keys
          ``'nums'`` (chunks per sequence), ``'tot'`` (total chunks),
          ``'mlist'`` (sequence index of every chunk), ``'mlist_len'``,
          ``'offsetlist'`` (index of every chunk within its sequence),
          ``'batch_ptr'`` and ``'token_chunk_offset_ptr'``.
        * ``batch_ptr`` is an int32 tensor padded with ``PAD_SLOT_ID`` whose
          first ``mlist_len`` entries hold ``mlist``.
        * ``token_chunk_offset_ptr`` is an int32 tensor padded with
          ``PAD_SLOT_ID`` whose first ``mlist_len`` entries hold
          ``offsetlist``.

        The two padded tensors live on the same device as
        ``query_start_loc_p``.
    """
    # The bookkeeping below is done with Python/NumPy, so bring the
    # per-sequence lengths over to the CPU first.
    seqlens = query_start_loc_p.diff().to('cpu')
    # Allocate the padded buffers next to the input instead of on a
    # hard-coded 'cuda' device, so they follow wherever the caller's
    # tensors actually live.
    device = query_start_loc_p.device

    nums_dict = {}  # type: ignore
    batch_ptr = None
    token_chunk_offset_ptr = None
    for BLOCK_M in [8]:  # cover all BLOCK_M values
        # Ceiling division: number of BLOCK_M-sized chunks per sequence.
        nums = -(-seqlens // BLOCK_M)
        # mlist[i] is the index of the sequence that chunk i belongs to.
        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
        mlist_len = len(mlist)
        MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2

        # offsetlist[i] is the position of chunk i inside its own sequence.
        # Iterate plain Python ints rather than 0-d tensors.
        offsets: list[int] = []
        for num in nums.tolist():
            offsets.extend(range(num))
        offsetlist = torch.tensor(offsets, dtype=torch.int32)

        if batch_ptr is None:
            # First BLOCK_M value: allocate the padded buffers.
            batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
                                   PAD_SLOT_ID,
                                   dtype=torch.int32,
                                   device=device)
            token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
                                                PAD_SLOT_ID,
                                                dtype=torch.int32,
                                                device=device)
        elif batch_ptr.nelement() < MAX_NUM_PROGRAMS:
            # A later BLOCK_M value needs more room: grow and re-pad the
            # buffers that were allocated on an earlier iteration.
            batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
            token_chunk_offset_ptr.resize_(  # type: ignore
                MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)

        batch_ptr[0:mlist_len].copy_(mlist)
        token_chunk_offset_ptr[  # type: ignore
            0:mlist_len].copy_(offsetlist)

        nums_dict[BLOCK_M] = {
            'nums': nums,
            'tot': nums.sum().item(),
            'mlist': mlist,
            'mlist_len': mlist_len,
            'offsetlist': offsetlist,
            'batch_ptr': batch_ptr,
            'token_chunk_offset_ptr': token_chunk_offset_ptr,
        }

    return nums_dict, batch_ptr, token_chunk_offset_ptr
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment