Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6ade99ea
Unverified
Commit
6ade99ea
authored
Aug 09, 2025
by
Thomas Parnell
Committed by
GitHub
Aug 08, 2025
Browse files
[V1] [Hybrid] Support Minimax-Text-01 in V1 (#22151)
Signed-off-by:
Thomas Parnell
<
tpa@zurich.ibm.com
>
parent
3157aebb
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
234 additions
and
42 deletions
+234
-42
vllm/model_executor/layers/lightning_attn.py
vllm/model_executor/layers/lightning_attn.py
+1
-1
vllm/model_executor/layers/mamba/mamba_utils.py
vllm/model_executor/layers/mamba/mamba_utils.py
+11
-0
vllm/model_executor/models/minimax_text_01.py
vllm/model_executor/models/minimax_text_01.py
+152
-40
vllm/v1/attention/backends/linear_attn.py
vllm/v1/attention/backends/linear_attn.py
+67
-0
vllm/v1/attention/backends/mamba_selectors.py
vllm/v1/attention/backends/mamba_selectors.py
+3
-1
No files found.
vllm/model_executor/layers/lightning_attn.py
View file @
6ade99ea
...
@@ -532,7 +532,7 @@ def _linear_attn_decode_kernel(
...
@@ -532,7 +532,7 @@ def _linear_attn_decode_kernel(
pid_d
=
tl
.
program_id
(
2
)
# dimension block index
pid_d
=
tl
.
program_id
(
2
)
# dimension block index
# Load slot index for the current batch
# Load slot index for the current batch
slot_id
=
tl
.
load
(
slot_idx
+
pid_b
)
slot_id
=
tl
.
load
(
slot_idx
+
pid_b
)
.
to
(
tl
.
int64
)
# Skip if slot_id is -1 (padding)
# Skip if slot_id is -1 (padding)
if
slot_id
==
-
1
:
if
slot_id
==
-
1
:
...
...
vllm/model_executor/layers/mamba/mamba_utils.py
View file @
6ade99ea
...
@@ -5,6 +5,17 @@ from vllm.distributed import divide
...
@@ -5,6 +5,17 @@ from vllm.distributed import divide
class
MambaStateShapeCalculator
:
class
MambaStateShapeCalculator
:
@
classmethod
def
linear_attention_state_shape
(
cls
,
num_heads
:
int
,
tp_size
:
int
,
head_dim
:
int
,
)
->
tuple
[
tuple
[
int
,
int
,
int
],
...]:
state_shape
=
(
num_heads
//
tp_size
,
head_dim
,
head_dim
)
return
(
state_shape
,
)
@
classmethod
@
classmethod
def
mamba1_state_shape
(
def
mamba1_state_shape
(
cls
,
cls
,
...
...
vllm/model_executor/models/minimax_text_01.py
View file @
6ade99ea
...
@@ -14,8 +14,9 @@ from einops import rearrange
...
@@ -14,8 +14,9 @@ from einops import rearrange
from
torch
import
nn
from
torch
import
nn
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm
import
envs
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_pp_group
,
get_tensor_model_parallel_rank
,
...
@@ -33,6 +34,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -33,6 +34,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateShapeCalculator
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -41,8 +45,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -41,8 +45,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.models.utils
import
maybe_prefix
from
vllm.model_executor.models.utils
import
maybe_prefix
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.attention.backends.linear_attn
import
LinearAttentionMetadata
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsV0Only
from
.interfaces
import
HasInnerState
,
IsHybrid
from
.minimax_cache
import
MinimaxCacheManager
,
MinimaxCacheParams
from
.minimax_cache
import
MinimaxCacheManager
,
MinimaxCacheParams
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
...
@@ -327,7 +332,17 @@ class MiniMaxText01LinearKernel:
...
@@ -327,7 +332,17 @@ class MiniMaxText01LinearKernel:
return
rearrange
(
output
.
squeeze
(
0
),
"h n d -> n (h d)"
)
return
rearrange
(
output
.
squeeze
(
0
),
"h n d -> n (h d)"
)
class
MiniMaxText01LinearAttention
(
nn
.
Module
):
class
MiniMaxText01LinearAttention
(
nn
.
Module
,
MambaBase
):
@
property
def
mamba_type
(
self
)
->
str
:
return
"linear_attention"
def
get_state_shape
(
self
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...]]:
return
MambaStateShapeCalculator
.
linear_attention_state_shape
(
num_heads
=
self
.
num_heads
,
tp_size
=
self
.
tp_size
,
head_dim
=
self
.
head_dim
)
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -359,6 +374,7 @@ class MiniMaxText01LinearAttention(nn.Module):
...
@@ -359,6 +374,7 @@ class MiniMaxText01LinearAttention(nn.Module):
self
.
tp_heads
=
self
.
total_num_heads
//
self
.
tp_size
self
.
tp_heads
=
self
.
total_num_heads
//
self
.
tp_size
self
.
qkv_size
=
self
.
num_heads
*
self
.
head_dim
self
.
qkv_size
=
self
.
num_heads
*
self
.
head_dim
self
.
tp_hidden
=
self
.
head_dim
*
self
.
tp_heads
self
.
tp_hidden
=
self
.
head_dim
*
self
.
tp_heads
self
.
prefix
=
prefix
self
.
qkv_proj
=
ColumnParallelLinear
(
self
.
qkv_proj
=
ColumnParallelLinear
(
hidden_size
,
hidden_size
,
...
@@ -397,6 +413,12 @@ class MiniMaxText01LinearAttention(nn.Module):
...
@@ -397,6 +413,12 @@ class MiniMaxText01LinearAttention(nn.Module):
self
.
tp_heads
:(
self
.
tp_rank
+
1
)
*
self
.
tp_heads
:(
self
.
tp_rank
+
1
)
*
self
.
tp_heads
].
contiguous
()
self
.
tp_heads
].
contiguous
()
if
envs
.
VLLM_USE_V1
:
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
@
staticmethod
@
staticmethod
def
weight_direct_load
(
param
:
torch
.
Tensor
,
def
weight_direct_load
(
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
)
->
None
:
loaded_weight
:
torch
.
Tensor
)
->
None
:
...
@@ -434,13 +456,14 @@ class MiniMaxText01LinearAttention(nn.Module):
...
@@ -434,13 +456,14 @@ class MiniMaxText01LinearAttention(nn.Module):
break
break
if
_prefill_idx
>=
len
(
state_indices_tensor
):
if
_prefill_idx
>=
len
(
state_indices_tensor
):
break
break
_start
=
attn_metadata
.
query_start_loc
[
_prefill_idx
]
# prefills are packed at end of batch in V1
_end
=
attn_metadata
.
query_start_loc
[
_prefill_idx
+
1
]
offset
=
attn_metadata
.
num_decode_tokens
if
envs
.
VLLM_USE_V1
else
0
slot_id
=
state_indices_tensor
[
_prefill_idx
]
_start
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
]
_end
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
+
1
]
slot_id
=
state_indices_tensor
[
offset
+
_prefill_idx
]
qs
=
q
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
qs
=
q
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
ks
=
k
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
ks
=
k
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
vs
=
v
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
vs
=
v
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
slot_id
=
state_indices_tensor
[
_prefill_idx
]
slice_layer_cache
=
kv_cache
[
slot_id
,
...]
slice_layer_cache
=
kv_cache
[
slot_id
,
...]
out_slice
=
MiniMaxText01LinearKernel
.
jit_linear_forward_prefix
(
out_slice
=
MiniMaxText01LinearKernel
.
jit_linear_forward_prefix
(
...
@@ -453,9 +476,13 @@ class MiniMaxText01LinearAttention(nn.Module):
...
@@ -453,9 +476,13 @@ class MiniMaxText01LinearAttention(nn.Module):
layer_idx
=
self
.
layer_idx
)
layer_idx
=
self
.
layer_idx
)
hidden
.
append
(
out_slice
.
contiguous
())
hidden
.
append
(
out_slice
.
contiguous
())
if
attn_metadata
.
num_decode_tokens
>
0
:
if
attn_metadata
.
num_decode_tokens
>
0
:
hidden
.
append
(
hidden_decode
=
self
.
_decode_infer
(
q
,
k
,
v
,
kv_cache
,
self
.
_decode_infer
(
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
state_indices_tensor
,
attn_metadata
))
attn_metadata
)
if
envs
.
VLLM_USE_V1
:
hidden
.
insert
(
0
,
hidden_decode
)
else
:
hidden
.
append
(
hidden_decode
)
if
not
hidden
:
if
not
hidden
:
return
torch
.
empty
((
0
,
q
.
size
(
-
1
)),
device
=
q
.
device
,
dtype
=
q
.
dtype
)
return
torch
.
empty
((
0
,
q
.
size
(
-
1
)),
device
=
q
.
device
,
dtype
=
q
.
dtype
)
...
@@ -465,11 +492,17 @@ class MiniMaxText01LinearAttention(nn.Module):
...
@@ -465,11 +492,17 @@ class MiniMaxText01LinearAttention(nn.Module):
def
_decode_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
def
_decode_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
):
attn_metadata
):
q
=
q
[
attn_metadata
.
num_prefill_tokens
:].
unsqueeze
(
2
).
contiguous
()
if
not
envs
.
VLLM_USE_V1
:
k
=
k
[
attn_metadata
.
num_prefill_tokens
:].
unsqueeze
(
2
).
contiguous
()
q
=
q
[
attn_metadata
.
num_prefill_tokens
:].
unsqueeze
(
2
).
contiguous
()
v
=
v
[
attn_metadata
.
num_prefill_tokens
:].
unsqueeze
(
2
).
contiguous
()
k
=
k
[
attn_metadata
.
num_prefill_tokens
:].
unsqueeze
(
2
).
contiguous
()
slot_id
=
state_indices_tensor
[
getattr
(
attn_metadata
,
"num_prefills"
,
0
v
=
v
[
attn_metadata
.
num_prefill_tokens
:].
unsqueeze
(
2
).
contiguous
()
):]
num_prefills
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
slot_id
=
state_indices_tensor
[
num_prefills
:]
else
:
q
=
q
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
k
=
k
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
v
=
v
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
slot_id
=
state_indices_tensor
[:
attn_metadata
.
num_decodes
]
hidden
=
linear_decode_forward_triton
(
q
,
k
,
v
,
kv_cache
,
self
.
tp_slope
,
hidden
=
linear_decode_forward_triton
(
q
,
k
,
v
,
kv_cache
,
self
.
tp_slope
,
slot_id
,
32
)
slot_id
,
32
)
return
hidden
return
hidden
...
@@ -483,17 +516,49 @@ class MiniMaxText01LinearAttention(nn.Module):
...
@@ -483,17 +516,49 @@ class MiniMaxText01LinearAttention(nn.Module):
q
,
k
,
v
=
torch
.
split
(
qkvact
,
[
self
.
head_dim
]
*
3
,
dim
=-
1
)
q
,
k
,
v
=
torch
.
split
(
qkvact
,
[
self
.
head_dim
]
*
3
,
dim
=-
1
)
forward_context
=
get_forward_context
()
forward_context
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
kv_cache
=
kv_caches
.
minimax_cache
if
envs
.
VLLM_USE_V1
:
state_indices_tensor
=
kv_caches
.
state_indices_tensor
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
LinearAttentionMetadata
)
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
num_prefills
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
if
num_prefills
>
0
:
num_decode_tokens
=
getattr
(
attn_metadata
,
"num_decode_tokens"
,
0
)
for
prefill_idx
in
range
(
num_prefills
):
q_start
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
prefill_idx
]
q_end
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
prefill_idx
+
1
]
query_len
=
q_end
-
q_start
context_len
=
attn_metadata
.
seq_lens
[
num_decode_tokens
+
prefill_idx
]
-
query_len
if
context_len
==
0
:
block_to_clear
=
state_indices_tensor
[
num_decode_tokens
+
prefill_idx
]
kv_cache
[
block_to_clear
,
...]
=
0
else
:
kv_cache
=
kv_caches
.
minimax_cache
state_indices_tensor
=
kv_caches
.
state_indices_tensor
decode_only
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
==
0
decode_only
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
==
0
if
not
decode_only
:
if
attn_metadata
is
None
:
hidden
=
self
.
_prefill_and_mix_infer
(
q
,
k
,
v
,
kv_cache
,
hidden
=
torch
.
empty
((
q
.
shape
[
0
],
q
.
shape
[
1
]
*
q
.
shape
[
2
])
,
state_indices_tensor
,
device
=
q
.
device
,
attn_metadata
)
dtype
=
q
.
dtype
)
else
:
else
:
hidden
=
self
.
_decode_infer
(
q
,
k
,
v
,
kv_cache
,
if
not
decode_only
:
state_indices_tensor
,
attn_metadata
)
hidden
=
self
.
_prefill_and_mix_infer
(
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
)
else
:
hidden
=
self
.
_decode_infer
(
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
)
hidden
=
self
.
norm
.
_forward
(
hidden
)
hidden
=
self
.
norm
.
_forward
(
hidden
)
gate
,
_
=
self
.
output_gate
(
hidden_states
)
gate
,
_
=
self
.
output_gate
(
hidden_states
)
...
@@ -541,6 +606,7 @@ class MiniMaxText01Attention(nn.Module):
...
@@ -541,6 +606,7 @@ class MiniMaxText01Attention(nn.Module):
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
rope_theta
=
rope_theta
self
.
sliding_window
=
sliding_window
self
.
sliding_window
=
sliding_window
self
.
prefix
=
prefix
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
hidden_size
,
...
@@ -575,7 +641,12 @@ class MiniMaxText01Attention(nn.Module):
...
@@ -575,7 +641,12 @@ class MiniMaxText01Attention(nn.Module):
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
attn_metadata
.
rotary_emb
(
positions
,
q
,
k
)
if
envs
.
VLLM_USE_V1
:
if
attn_metadata
is
not
None
:
q
,
k
=
attn_metadata
[
f
"
{
self
.
prefix
}
.attn"
].
rotary_emb
(
positions
,
q
,
k
)
else
:
q
,
k
=
attn_metadata
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -595,6 +666,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
...
@@ -595,6 +666,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
)
->
None
:
)
->
None
:
self
.
_ilayer
=
layer_id
self
.
_ilayer
=
layer_id
self
.
_irank
=
get_tensor_model_parallel_rank
()
self
.
_irank
=
get_tensor_model_parallel_rank
()
self
.
prefix
=
prefix
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -876,8 +948,9 @@ class MiniMaxText01Model(nn.Module):
...
@@ -876,8 +948,9 @@ class MiniMaxText01Model(nn.Module):
self
.
_dtype
=
_dummy
.
dtype
self
.
_dtype
=
_dummy
.
dtype
del
_dummy
del
_dummy
self
.
minimax_cache
=
MinimaxCacheManager
(
dtype
=
torch
.
float32
,
if
not
envs
.
VLLM_USE_V1
:
cache_shape
=
self
.
cache_shape
)
self
.
minimax_cache
=
MinimaxCacheManager
(
dtype
=
torch
.
float32
,
cache_shape
=
self
.
cache_shape
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
head_dim
=
getattr
(
config
,
"head_dim"
,
None
)
head_dim
=
getattr
(
config
,
"head_dim"
,
None
)
...
@@ -944,23 +1017,27 @@ class MiniMaxText01Model(nn.Module):
...
@@ -944,23 +1017,27 @@ class MiniMaxText01Model(nn.Module):
**
kwargs
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
**
kwargs
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
forward_context
=
get_forward_context
()
forward_context
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
if
attn_metadata
is
None
:
if
not
envs
.
VLLM_USE_V1
and
attn_metadata
is
None
:
return
None
return
None
if
"request_ids_to_seq_ids"
not
in
kwargs
:
if
"request_ids_to_seq_ids"
not
in
kwargs
:
kwargs
[
"request_ids_to_seq_ids"
]
=
{}
kwargs
[
"request_ids_to_seq_ids"
]
=
{}
if
"finished_requests_ids"
not
in
kwargs
:
if
"finished_requests_ids"
not
in
kwargs
:
kwargs
[
"finished_requests_ids"
]
=
[]
kwargs
[
"finished_requests_ids"
]
=
[]
(
if
not
envs
.
VLLM_USE_V1
:
minimax_cache_tensors
,
(
state_indices_tensor
,
minimax_cache_tensors
,
)
=
self
.
minimax_cache
.
current_run_tensors
(
**
kwargs
)
state_indices_tensor
,
if
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
>
0
:
)
=
self
.
minimax_cache
.
current_run_tensors
(
**
kwargs
)
self
.
_clear_prefill_cache
(
attn_metadata
,
minimax_cache_tensors
,
if
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
>
0
:
**
kwargs
)
self
.
_clear_prefill_cache
(
attn_metadata
,
minimax_cache_tensors
,
**
kwargs
)
minimax_cache_params
=
MinimaxCacheParams
(
minimax_cache_tensors
,
state_indices_tensor
)
else
:
minimax_cache_params
=
None
minimax_cache_params
=
MinimaxCacheParams
(
minimax_cache_tensors
,
state_indices_tensor
)
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
hidden_states
=
self
.
embed_scale
*
self
.
embed_tokens
(
input_ids
)
hidden_states
=
self
.
embed_scale
*
self
.
embed_tokens
(
input_ids
)
...
@@ -973,11 +1050,22 @@ class MiniMaxText01Model(nn.Module):
...
@@ -973,11 +1050,22 @@ class MiniMaxText01Model(nn.Module):
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
minimax_cache_index
=
0
minimax_cache_index
=
0
attn_metadata
.
rotary_emb
=
self
.
rotary_emb
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
if
attn_metadata
is
not
None
:
# TODO (tdoublep): this whole thing with the rotary_emb is
# weird. we shouldn't be passing it via attn_metadata imo.
if
envs
.
VLLM_USE_V1
:
if
isinstance
(
layer
.
self_attn
,
MiniMaxText01Attention
):
attn_metadata
[
layer
.
prefix
+
".attn"
].
rotary_emb
=
self
.
rotary_emb
else
:
attn_metadata
.
rotary_emb
=
self
.
rotary_emb
_caches
=
None
_caches
=
None
if
isinstance
(
layer
.
self_attn
,
MiniMaxText01LinearAttention
):
if
not
envs
.
VLLM_USE_V1
and
isinstance
(
layer
.
self_attn
,
MiniMaxText01LinearAttention
):
current_state_layer
=
minimax_cache_index
current_state_layer
=
minimax_cache_index
_caches
=
minimax_cache_params
.
at_layer_idx
(
_caches
=
minimax_cache_params
.
at_layer_idx
(
current_state_layer
)
current_state_layer
)
...
@@ -1002,8 +1090,7 @@ class MiniMaxText01Model(nn.Module):
...
@@ -1002,8 +1090,7 @@ class MiniMaxText01Model(nn.Module):
return
hidden_states
return
hidden_states
class
MiniMaxText01ForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsHybrid
,
class
MiniMaxText01ForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsHybrid
):
SupportsV0Only
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
...
@@ -1321,3 +1408,28 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
...
@@ -1321,3 +1408,28 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
load_basic_weight
(
name
,
loaded_weight
,
self
)
load_basic_weight
(
name
,
loaded_weight
,
self
)
return
loaded_params
return
loaded_params
@
classmethod
def
get_mamba_state_shape_from_config
(
cls
,
vllm_config
:
"VllmConfig"
,
use_v1
:
bool
=
True
,
)
->
tuple
[
tuple
[
int
,
...],
...]:
"""Calculate shape for MiniMaxText01LinearAttention cache.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- state_shape: Shape of the cache
"""
parallel_config
=
vllm_config
.
parallel_config
hf_config
=
vllm_config
.
model_config
.
hf_config
return
MambaStateShapeCalculator
.
linear_attention_state_shape
(
num_heads
=
hf_config
.
num_attention_heads
,
tp_size
=
parallel_config
.
tensor_parallel_size
,
head_dim
=
hf_config
.
head_dim
,
)
vllm/v1/attention/backends/linear_attn.py
0 → 100644
View file @
6ade99ea
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
ClassVar
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.config
import
VllmConfig
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
split_decodes_and_prefills
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
MambaSpec
class
LinearAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_builder_cls
()
->
type
[
"LinearAttentionMetadataBuilder"
]:
return
LinearAttentionMetadataBuilder
@
dataclass
class
LinearAttentionMetadata
:
num_prefills
:
int
num_prefill_tokens
:
int
num_decodes
:
int
num_decode_tokens
:
int
query_start_loc
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
state_indices_tensor
:
torch
.
Tensor
# shape: [batch,]
class
LinearAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
LinearAttentionMetadata
]):
reorder_batch_threshold
:
ClassVar
[
int
]
=
1
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
assert
isinstance
(
kv_cache_spec
,
MambaSpec
)
self
.
kv_cache_spec
=
kv_cache_spec
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
)
->
LinearAttentionMetadata
:
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
1
))
attn_metadata
=
LinearAttentionMetadata
(
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
state_indices_tensor
=
state_indices_tensor
,
)
return
attn_metadata
vllm/v1/attention/backends/mamba_selectors.py
View file @
6ade99ea
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.v1.attention.backends.linear_attn
import
LinearAttentionBackend
from
vllm.v1.attention.backends.mamba1_attn
import
Mamba1AttentionBackend
from
vllm.v1.attention.backends.mamba1_attn
import
Mamba1AttentionBackend
from
vllm.v1.attention.backends.mamba_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.mamba_attn
import
Mamba2AttentionBackend
...
@@ -8,9 +9,10 @@ from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
...
@@ -8,9 +9,10 @@ from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
def
get_mamba_attn_backend
(
mamba_type
:
str
)
->
type
[
AttentionBackend
]:
def
get_mamba_attn_backend
(
mamba_type
:
str
)
->
type
[
AttentionBackend
]:
if
mamba_type
==
"mamba1"
:
if
mamba_type
==
"mamba1"
:
return
Mamba1AttentionBackend
return
Mamba1AttentionBackend
if
mamba_type
==
"mamba2"
:
if
mamba_type
==
"mamba2"
:
return
Mamba2AttentionBackend
return
Mamba2AttentionBackend
if
mamba_type
==
"linear_attention"
:
return
LinearAttentionBackend
raise
NotImplementedError
(
f
"Mamba Attention type
{
mamba_type
}
is not "
raise
NotImplementedError
(
f
"Mamba Attention type
{
mamba_type
}
is not "
"supported yet."
)
"supported yet."
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment