Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
922f3164
Unverified
Commit
922f3164
authored
Jul 11, 2025
by
Michael Goin
Committed by
GitHub
Jul 11, 2025
Browse files
[Model] Support HF format of minimax (#20211)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
5923ab95
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
11 deletions
+36
-11
tests/models/registry.py
tests/models/registry.py
+2
-0
vllm/model_executor/models/minimax_text_01.py
vllm/model_executor/models/minimax_text_01.py
+33
-11
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
tests/models/registry.py
View file @
922f3164
...
@@ -218,6 +218,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -218,6 +218,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"MiniCPM3ForCausalLM"
:
_HfExamplesInfo
(
"openbmb/MiniCPM3-4B"
,
"MiniCPM3ForCausalLM"
:
_HfExamplesInfo
(
"openbmb/MiniCPM3-4B"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"MiniMaxForCausalLM"
:
_HfExamplesInfo
(
"MiniMaxAI/MiniMax-Text-01-hf"
,
min_transformers_version
=
"4.53"
),
"MiniMaxText01ForCausalLM"
:
_HfExamplesInfo
(
"MiniMaxAI/MiniMax-Text-01"
,
"MiniMaxText01ForCausalLM"
:
_HfExamplesInfo
(
"MiniMaxAI/MiniMax-Text-01"
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
revision
=
"a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"
),
# noqa: E501
revision
=
"a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"
),
# noqa: E501
...
...
vllm/model_executor/models/minimax_text_01.py
View file @
922f3164
...
@@ -667,16 +667,24 @@ class MiniMaxText01DecoderLayer(nn.Module):
...
@@ -667,16 +667,24 @@ class MiniMaxText01DecoderLayer(nn.Module):
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
if
config
.
attention_type
==
0
:
if
config
.
attention_type
==
0
:
self
.
layernorm_attention_alpha
=
getattr
(
self
.
layernorm_attention_alpha
=
getattr
(
config
,
'layernorm_linear_attention_alpha'
,
1
)
config
,
'layernorm_linear_attention_alpha'
,
getattr
(
config
,
'linear_attn_alpha_factor'
,
1
))
self
.
layernorm_attention_beta
=
getattr
(
self
.
layernorm_attention_beta
=
getattr
(
config
,
'layernorm_linear_attention_beta'
,
1
)
config
,
'layernorm_linear_attention_beta'
,
getattr
(
config
,
'linear_attn_beta_factor'
,
1
))
else
:
else
:
self
.
layernorm_attention_alpha
=
getattr
(
self
.
layernorm_attention_alpha
=
getattr
(
config
,
'layernorm_full_attention_alpha'
,
1
)
config
,
'layernorm_full_attention_alpha'
,
getattr
(
config
,
'full_attn_alpha_factor'
,
1
))
self
.
layernorm_attention_beta
=
getattr
(
self
.
layernorm_attention_beta
=
getattr
(
config
,
'layernorm_full_attention_beta'
,
1
)
config
,
'layernorm_full_attention_beta'
,
self
.
layernorm_mlp_alpha
=
getattr
(
config
,
'layernorm_mlp_alpha'
,
1
)
getattr
(
config
,
'full_attn_beta_factor'
,
1
))
self
.
layernorm_mlp_beta
=
getattr
(
config
,
'layernorm_mlp_beta'
,
1
)
self
.
layernorm_mlp_alpha
=
getattr
(
config
,
'layernorm_mlp_alpha'
,
getattr
(
config
,
'mlp_alpha_factor'
,
1
))
self
.
layernorm_mlp_beta
=
getattr
(
config
,
'layernorm_mlp_beta'
,
getattr
(
config
,
'mlp_beta_factor'
,
1
))
self
.
postnorm
=
getattr
(
config
,
'postnorm'
,
False
)
self
.
postnorm
=
getattr
(
config
,
'postnorm'
,
False
)
self
.
shared_moe
=
False
self
.
shared_moe
=
False
...
@@ -794,6 +802,18 @@ class MiniMaxText01Model(nn.Module):
...
@@ -794,6 +802,18 @@ class MiniMaxText01Model(nn.Module):
self
.
decoder_attention_types
=
getattr
(
self
.
decoder_attention_types
=
getattr
(
config
,
"attn_type_list"
,
False
)
or
getattr
(
config
,
"attn_type_list"
,
False
)
or
getattr
(
config
,
"decoder_attention_types"
,
False
)
config
,
"decoder_attention_types"
,
False
)
# The HF format uses "layer_types" instead of "attn_type_list"
# where "linear_attention" is 0 and "full_attention" is 1
if
not
self
.
decoder_attention_types
and
hasattr
(
config
,
"layer_types"
):
self
.
decoder_attention_types
=
[]
for
layer_type
in
config
.
layer_types
:
if
layer_type
==
"linear_attention"
:
self
.
decoder_attention_types
.
append
(
0
)
elif
layer_type
==
"full_attention"
:
self
.
decoder_attention_types
.
append
(
1
)
else
:
raise
ValueError
(
f
"Unsupported layer type:
{
layer_type
}
"
)
# Default to full attention
if
not
self
.
decoder_attention_types
:
if
not
self
.
decoder_attention_types
:
self
.
decoder_attention_types
=
[
1
]
*
config
.
num_hidden_layers
self
.
decoder_attention_types
=
[
1
]
*
config
.
num_hidden_layers
self
.
num_layers
=
config
.
num_hidden_layers
self
.
num_layers
=
config
.
num_hidden_layers
...
@@ -1022,8 +1042,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
...
@@ -1022,8 +1042,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
else
:
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
lm_head
=
PPMissingLayer
()
self
.
lm_head
.
float
()
self
.
lm_head
.
float
()
flash_layer_count
=
sum
(
1
for
attn_type
in
self
.
config
.
attn_type_list
flash_layer_count
=
sum
(
if
attn_type
==
1
)
1
for
attn_type
in
self
.
model
.
decoder_attention_types
if
attn_type
==
1
)
self
.
kv_cache
=
[
torch
.
tensor
([])
for
_
in
range
(
flash_layer_count
)]
self
.
kv_cache
=
[
torch
.
tensor
([])
for
_
in
range
(
flash_layer_count
)]
return
return
...
@@ -1085,9 +1106,10 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
...
@@ -1085,9 +1106,10 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
return
None
return
None
def
is_linear_attn_layer
(
layer_idx
:
int
)
->
bool
:
def
is_linear_attn_layer
(
layer_idx
:
int
)
->
bool
:
if
layer_idx
is
None
or
not
hasattr
(
self
.
config
,
"attn_type_list"
):
if
layer_idx
is
None
or
layer_idx
>=
len
(
self
.
model
.
decoder_attention_types
):
return
False
return
False
return
self
.
config
.
att
n_type
_list
[
layer_idx
]
==
0
return
self
.
model
.
decoder_attentio
n_type
s
[
layer_idx
]
==
0
def
is_moe_weight
(
name
:
str
)
->
bool
:
def
is_moe_weight
(
name
:
str
)
->
bool
:
return
"block_sparse_moe"
in
name
and
not
name
.
endswith
(
".bias"
)
return
"block_sparse_moe"
in
name
and
not
name
.
endswith
(
".bias"
)
...
@@ -1275,7 +1297,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
...
@@ -1275,7 +1297,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
weight_at_layer
=
which_layer
(
name
)
weight_at_layer
=
which_layer
(
name
)
if
weight_at_layer
and
weight_at_layer
>=
len
(
if
weight_at_layer
and
weight_at_layer
>=
len
(
self
.
config
.
att
n_type
_list
):
self
.
model
.
decoder_attentio
n_type
s
):
continue
continue
if
is_layer_norm_weight
(
name
):
if
is_layer_norm_weight
(
name
):
...
...
vllm/model_executor/models/registry.py
View file @
922f3164
...
@@ -34,6 +34,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -34,6 +34,7 @@ _TEXT_GENERATION_MODELS = {
"AquilaModel"
:
(
"llama"
,
"LlamaForCausalLM"
),
"AquilaModel"
:
(
"llama"
,
"LlamaForCausalLM"
),
"AquilaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
# AquilaChat2
"AquilaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
# AquilaChat2
"ArcticForCausalLM"
:
(
"arctic"
,
"ArcticForCausalLM"
),
"ArcticForCausalLM"
:
(
"arctic"
,
"ArcticForCausalLM"
),
"MiniMaxForCausalLM"
:
(
"minimax_text_01"
,
"MiniMaxText01ForCausalLM"
),
"MiniMaxText01ForCausalLM"
:
(
"minimax_text_01"
,
"MiniMaxText01ForCausalLM"
),
"MiniMaxText01ForCausalLM"
:
(
"minimax_text_01"
,
"MiniMaxText01ForCausalLM"
),
"MiniMaxM1ForCausalLM"
:
(
"minimax_text_01"
,
"MiniMaxText01ForCausalLM"
),
"MiniMaxM1ForCausalLM"
:
(
"minimax_text_01"
,
"MiniMaxText01ForCausalLM"
),
# baichuan-7b, upper case 'C' in the class name
# baichuan-7b, upper case 'C' in the class name
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment