Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
63babd17
Unverified
Commit
63babd17
authored
Mar 30, 2026
by
PikaPikachu
Committed by
GitHub
Mar 30, 2026
Browse files
[Model][Quantization] Add GGUF support for MiniMax-M2.1 (#36965)
Signed-off-by:
kangletian
<
Letian.Kang@amd.com
>
parent
fec5aeca
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
137 additions
and
10 deletions
+137
-10
vllm/config/model.py
vllm/config/model.py
+1
-0
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+14
-1
vllm/model_executor/model_loader/gguf_loader.py
vllm/model_executor/model_loader/gguf_loader.py
+74
-7
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+43
-0
vllm/model_executor/models/minimax_m2.py
vllm/model_executor/models/minimax_m2.py
+5
-2
No files found.
vllm/config/model.py
View file @
63babd17
...
@@ -948,6 +948,7 @@ class ModelConfig:
...
@@ -948,6 +948,7 @@ class ModelConfig:
# imports during override detection (e.g., MXFP4 imports Triton)
# imports during override detection (e.g., MXFP4 imports Triton)
"mxfp4"
,
"mxfp4"
,
"cpu_awq"
,
"cpu_awq"
,
"gguf"
,
]
]
quantization_methods
=
[
quantization_methods
=
[
q
for
q
in
supported_quantization
if
q
not
in
overrides
q
for
q
in
supported_quantization
if
q
not
in
overrides
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
63babd17
...
@@ -3,7 +3,10 @@
...
@@ -3,7 +3,10 @@
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
types
import
MappingProxyType
from
types
import
MappingProxyType
from
typing
import
Any
from
typing
import
TYPE_CHECKING
,
Any
if
TYPE_CHECKING
:
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
import
gguf
import
gguf
import
torch
import
torch
...
@@ -79,6 +82,16 @@ class GGUFConfig(QuantizationConfig):
...
@@ -79,6 +82,16 @@ class GGUFConfig(QuantizationConfig):
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"GGUFConfig"
:
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"GGUFConfig"
:
return
cls
()
return
cls
()
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
:
dict
[
str
,
Any
],
user_quant
:
str
|
None
)
->
"QuantizationMethods | None"
:
# When user explicitly specifies --quantization gguf, override
# whatever quantization method is in the HF model config (e.g. fp8).
if
user_quant
==
"gguf"
:
return
"gguf"
return
None
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
"QuantizeMethodBase | None"
:
)
->
"QuantizeMethodBase | None"
:
...
...
vllm/model_executor/model_loader/gguf_loader.py
View file @
63babd17
...
@@ -24,6 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -24,6 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import (
get_gguf_extra_tensor_names
,
get_gguf_extra_tensor_names
,
get_gguf_weight_type_map
,
get_gguf_weight_type_map
,
gguf_quant_weights_iterator
,
gguf_quant_weights_iterator
,
gguf_quant_weights_iterator_multi
,
)
)
from
vllm.transformers_utils.gguf_utils
import
detect_gguf_multimodal
from
vllm.transformers_utils.gguf_utils
import
detect_gguf_multimodal
from
vllm.utils.torch_utils
import
set_default_torch_dtype
from
vllm.utils.torch_utils
import
set_default_torch_dtype
...
@@ -74,6 +75,31 @@ class GGUFModelLoader(BaseModelLoader):
...
@@ -74,6 +75,31 @@ class GGUFModelLoader(BaseModelLoader):
"or <repo_id>:<quant_type>)"
"or <repo_id>:<quant_type>)"
)
)
@staticmethod
def _get_all_gguf_files(model_path: str) -> list[str]:
    """Discover all GGUF shard files from a single shard path.

    Supports variable-width shard indices by dynamically detecting
    the padding from the original filename.
    E.g. ``*-00001-of-00005.gguf`` → all 5 shards,
    ``*-01-of-15.gguf`` → all 15 shards.

    Args:
        model_path: Path to one GGUF file, optionally ending in
            ``-<index>-of-<total>.gguf``.

    Returns:
        The sibling shard paths that exist on disk, in ascending index
        order. Falls back to ``[model_path]`` when the path does not
        follow the shard naming pattern or when no shard is found.
    """
    match = re.search(r"-(\d+)-of-(\d+)\.gguf$", model_path)
    if not match:
        return [model_path]

    index_width = len(match.group(1))
    # Keep the "of" field exactly as written in the filename: re-padding it
    # with the index width would miss files whose two fields differ in width.
    total_field = match.group(2)
    total = int(total_field)
    prefix = model_path[: match.start(1)]
    suffix = model_path[match.end(2) :]

    candidates = [
        f"{prefix}{i:0{index_width}d}-of-{total_field}{suffix}"
        for i in range(1, total + 1)
    ]
    files = [path for path in candidates if os.path.isfile(path)]
    if not files:
        return [model_path]

    logger.info("Discovered %d GGUF shard files", len(files))
    if len(files) < total:
        # A partial shard set would otherwise be loaded without any hint
        # that some weights are absent.
        logger.warning(
            "Only %d of %d GGUF shard files were found next to %s; "
            "missing shards will not be loaded",
            len(files),
            total,
            model_path,
        )
    return files
def
_get_gguf_weights_map
(
self
,
model_config
:
ModelConfig
):
def
_get_gguf_weights_map
(
self
,
model_config
:
ModelConfig
):
"""
"""
GGUF uses this naming convention for their tensors from HF checkpoint:
GGUF uses this naming convention for their tensors from HF checkpoint:
...
@@ -145,6 +171,29 @@ class GGUFModelLoader(BaseModelLoader):
...
@@ -145,6 +171,29 @@ class GGUFModelLoader(BaseModelLoader):
r
"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
r
"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
)
)
)
)
if
model_type
==
"minimax_m2"
:
model_type
=
"minimax-m2"
# GGUF layer map assumes merged expert weights
# map them manually like deepseek2
for
idx
in
range
(
config
.
num_hidden_layers
):
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.exp_probs_b.bias"
]
=
(
f
"model.layers.
{
idx
}
.block_sparse_moe.e_score_correction_bias"
)
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_down_exps.weight"
]
=
(
f
"model.layers.
{
idx
}
.block_sparse_moe.experts.0.w2.weight"
)
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_gate_exps.weight"
]
=
(
f
"model.layers.
{
idx
}
.block_sparse_moe.experts.0.w1.weight"
)
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_up_exps.weight"
]
=
(
f
"model.layers.
{
idx
}
.block_sparse_moe.experts.0.w3.weight"
)
sideload_params
.
append
(
re
.
compile
(
f
"model
\\
.layers
\\
.
{
idx
}
"
r
"\.block_sparse_moe\.experts\.(gate_up_proj|down_proj)"
)
)
arch
=
None
arch
=
None
for
key
,
value
in
gguf
.
MODEL_ARCH_NAMES
.
items
():
for
key
,
value
in
gguf
.
MODEL_ARCH_NAMES
.
items
():
...
@@ -190,6 +239,13 @@ class GGUFModelLoader(BaseModelLoader):
...
@@ -190,6 +239,13 @@ class GGUFModelLoader(BaseModelLoader):
revert_hf_rename
(
name
):
tensor
for
name
,
tensor
in
state_dict
.
items
()
revert_hf_rename
(
name
):
tensor
for
name
,
tensor
in
state_dict
.
items
()
}
}
if
model_type
==
"minimax-m2"
and
not
hf_checkpoint_map
:
# Reverse HF convention: mlp -> block_sparse_moe
state_dict
=
{
name
.
replace
(
".mlp."
,
".block_sparse_moe."
):
tensor
for
name
,
tensor
in
state_dict
.
items
()
}
def
find_hf_name_in_tensor_map
(
hf_name
:
str
)
->
str
|
None
:
def
find_hf_name_in_tensor_map
(
hf_name
:
str
)
->
str
|
None
:
"""
"""
Map HuggingFace parameter name to GGUF tensor name.
Map HuggingFace parameter name to GGUF tensor name.
...
@@ -277,9 +333,10 @@ class GGUFModelLoader(BaseModelLoader):
...
@@ -277,9 +333,10 @@ class GGUFModelLoader(BaseModelLoader):
model_name_or_path
:
str
,
model_name_or_path
:
str
,
gguf_to_hf_name_map
:
dict
[
str
,
str
],
gguf_to_hf_name_map
:
dict
[
str
,
str
],
)
->
dict
[
str
,
str
]:
)
->
dict
[
str
,
str
]:
weight_type_map
=
get_gguf_weight_type_map
(
gguf_files
=
self
.
_get_all_gguf_files
(
model_name_or_path
)
model_name_or_path
,
gguf_to_hf_name_map
weight_type_map
=
{}
)
for
f
in
gguf_files
:
weight_type_map
.
update
(
get_gguf_weight_type_map
(
f
,
gguf_to_hf_name_map
))
is_multimodal
=
hasattr
(
model_config
.
hf_config
,
"vision_config"
)
is_multimodal
=
hasattr
(
model_config
.
hf_config
,
"vision_config"
)
if
is_multimodal
:
if
is_multimodal
:
mmproj_file
=
detect_gguf_multimodal
(
model_name_or_path
)
mmproj_file
=
detect_gguf_multimodal
(
model_name_or_path
)
...
@@ -321,7 +378,15 @@ class GGUFModelLoader(BaseModelLoader):
...
@@ -321,7 +378,15 @@ class GGUFModelLoader(BaseModelLoader):
)
)
yield
from
gguf_quant_weights_iterator
(
mmproj_file
,
gguf_to_hf_name_map
)
yield
from
gguf_quant_weights_iterator
(
mmproj_file
,
gguf_to_hf_name_map
)
yield
from
gguf_quant_weights_iterator
(
model_name_or_path
,
gguf_to_hf_name_map
)
gguf_files
=
self
.
_get_all_gguf_files
(
model_name_or_path
)
if
len
(
gguf_files
)
>
1
:
yield
from
gguf_quant_weights_iterator_multi
(
gguf_files
,
gguf_to_hf_name_map
)
else
:
yield
from
gguf_quant_weights_iterator
(
model_name_or_path
,
gguf_to_hf_name_map
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
)
self
.
_prepare_weights
(
model_config
)
...
@@ -340,9 +405,11 @@ class GGUFModelLoader(BaseModelLoader):
...
@@ -340,9 +405,11 @@ class GGUFModelLoader(BaseModelLoader):
local_model_path
=
self
.
_prepare_weights
(
model_config
)
local_model_path
=
self
.
_prepare_weights
(
model_config
)
gguf_weights_map
=
self
.
_get_gguf_weights_map
(
model_config
)
gguf_weights_map
=
self
.
_get_gguf_weights_map
(
model_config
)
# we can only know if tie word embeddings after mapping weights
# we can only know if tie word embeddings after mapping weights
if
"lm_head.weight"
in
get_gguf_extra_tensor_names
(
gguf_files
=
self
.
_get_all_gguf_files
(
local_model_path
)
local_model_path
,
gguf_weights_map
all_extra_names
=
[]
):
for
f
in
gguf_files
:
all_extra_names
.
extend
(
get_gguf_extra_tensor_names
(
f
,
gguf_weights_map
))
if
"lm_head.weight"
in
all_extra_names
:
model_config
.
hf_config
.
update
({
"tie_word_embeddings"
:
True
})
model_config
.
hf_config
.
update
({
"tie_word_embeddings"
:
True
})
weight_type_map
=
self
.
_get_gguf_weight_type
(
weight_type_map
=
self
.
_get_gguf_weight_type
(
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
63babd17
...
@@ -1222,6 +1222,49 @@ def gguf_quant_weights_iterator(
...
@@ -1222,6 +1222,49 @@ def gguf_quant_weights_iterator(
yield
name
,
param
yield
name
,
param
def gguf_quant_weights_iterator_multi(
    gguf_files: list[str], gguf_to_hf_name_map: dict[str, str]
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Stream weights from several GGUF shard files as torch tensors.

    As with ``gguf_quant_weights_iterator``, every weight-type entry is
    yielded before any weight data, to avoid issues with packed layers
    that have different quant types.

    Args:
        gguf_files: Paths of the shard files; one reader is opened per path.
        gguf_to_hf_name_map: GGUF tensor name -> HF parameter name. Tensors
            whose name is not a key of this map are ignored.

    Yields:
        ``(name, tensor)`` pairs: first the ``qweight_type`` entries for all
        tensors whose type is not F32/BF16/F16, then the weight data itself.
    """
    unquantized_types = ("F32", "BF16", "F16")
    shard_readers = [gguf.GGUFReader(path) for path in gguf_files]

    # Pass 1: quantization type markers, across every shard.
    for shard in shard_readers:
        for entry in shard.tensors:
            if entry.name not in gguf_to_hf_name_map:
                continue
            quant_type = entry.tensor_type
            if quant_type.name in unquantized_types:
                continue
            hf_name = gguf_to_hf_name_map[entry.name]
            yield hf_name.replace("weight", "qweight_type"), torch.tensor(quant_type)

    # Pass 2: the actual weight payloads, across every shard.
    for shard in shard_readers:
        for entry in shard.tensors:
            if entry.name not in gguf_to_hf_name_map:
                continue
            data = entry.data
            quant_type = entry.tensor_type
            hf_name = gguf_to_hf_name_map[entry.name]
            if quant_type.name not in unquantized_types:
                hf_name = hf_name.replace("weight", "qweight")

            if quant_type.name == "BF16" and data.dtype == np.uint8:
                # BF16 exposed as raw bytes: regroup into 16-bit words, swap
                # them when the reader reports byte order "S", then
                # reinterpret the bits as bfloat16.
                data = data.view(np.uint16)
                if shard.byte_order == "S":
                    data = data.byteswap()
                yield hf_name, torch.tensor(data).view(torch.bfloat16)
            else:
                yield hf_name, torch.tensor(data)
def
convert_pyslice_to_tensor
(
x
:
Any
)
->
torch
.
Tensor
:
def
convert_pyslice_to_tensor
(
x
:
Any
)
->
torch
.
Tensor
:
"""convert PySafeSlice object from safetensors to torch.Tensor
"""convert PySafeSlice object from safetensors to torch.Tensor
...
...
vllm/model_executor/models/minimax_m2.py
View file @
63babd17
...
@@ -331,7 +331,7 @@ class MiniMaxM2Model(nn.Module):
...
@@ -331,7 +331,7 @@ class MiniMaxM2Model(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
None
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.embed_tokens"
,
prefix
=
f
"
{
prefix
}
.embed_tokens"
,
)
)
else
:
else
:
...
@@ -518,7 +518,10 @@ class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -518,7 +518,10 @@ class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
)
)
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
ParallelLMHead
(
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
None
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
)
else
:
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
lm_head
=
PPMissingLayer
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment