Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a99b9f7d
Unverified
Commit
a99b9f7d
authored
Jul 14, 2025
by
Jee Jee Li
Committed by
GitHub
Jul 14, 2025
Browse files
[Quantization] add BNB for MixtralForCausalLM (#20893)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
c488b928
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
128 additions
and
20 deletions
+128
-20
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+6
-1
vllm/model_executor/models/granitemoe.py
vllm/model_executor/models/granitemoe.py
+102
-3
vllm/model_executor/models/granitemoeshared.py
vllm/model_executor/models/granitemoeshared.py
+2
-3
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+13
-8
vllm/model_executor/models/olmoe.py
vllm/model_executor/models/olmoe.py
+2
-1
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+2
-1
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+1
-3
No files found.
vllm/model_executor/model_loader/utils.py
View file @
a99b9f7d
...
...
@@ -227,7 +227,12 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported
=
[
"fp8"
,
"compressed-tensors"
,
"gptq_marlin"
,
"awq_marlin"
,
"quark"
"fp8"
,
"compressed-tensors"
,
"gptq_marlin"
,
"awq_marlin"
,
"quark"
,
"bitsandbytes"
,
]
vllm_supported_archs
=
ModelRegistry
.
get_supported_archs
()
...
...
vllm/model_executor/models/granitemoe.py
View file @
a99b9f7d
...
...
@@ -45,12 +45,14 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.
import
mixtral
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
AutoWeightsLoader
,
make_layers
,
maybe_prefix
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_layers
,
maybe_prefix
)
class
GraniteMoeMoE
(
nn
.
Module
):
...
...
@@ -307,6 +309,103 @@ class GraniteMoeModel(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
def _load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into this module's parameters.

    This function is copied from `MixtralModel.load_weights`, mainly to
    decouple from mixtral, avoiding impact on support like BNB
    quantization.

    Each ``(name, tensor)`` entry is routed to the first matching path:

    1. a kv-cache quantization scale, when ``self.quant_config`` reports a
       cache-scale name for it;
    2. a ``q_proj``/``k_proj``/``v_proj`` shard of the stacked ``qkv_proj``;
    3. an expert weight listed by ``FusedMoE.make_expert_params_mapping``;
    4. a plain parameter.

    An entry is skipped (not loaded, not reported) when it is a bias that
    has no matching parameter, when ``is_pp_missing_parameter`` reports it
    as missing, or when ``maybe_remap_kv_scale_name`` returns ``None``.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The names (after renaming) of the parameters that were loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="w1",
        ckpt_down_proj_name="w2",
        ckpt_up_proj_name="w3",
        num_experts=self.config.num_local_experts)

    bias_suffixes = (".bias", "_bias")
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if (self.quant_config is not None and
            (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                             loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue

        # Set when a matched entry must be dropped.  The inner loops
        # `break` instead of `continue` so that they never keep iterating
        # with an already-rewritten (or None) `name`.
        skipped = False
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(bias_suffixes) and name not in params_dict:
                skipped = True
                break
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                skipped = True
                break
            if name.endswith("scale"):
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    # Previously this resumed the mapping loop with
                    # name=None and failed on `weight_name not in name`.
                    skipped = True
                    break

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    skipped = True
                    break
                if name.endswith(bias_suffixes) and name not in params_dict:
                    skipped = True
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param,
                              loaded_weight,
                              name,
                              shard_id=shard_id,
                              expert_id=expert_id)
                break
            else:
                # Neither a stacked shard nor an expert weight.  The
                # `continue`s below act on the outer loop over `weights`.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(bias_suffixes) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
        if skipped:
            continue
        loaded_params.add(name)
    return loaded_params
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
new_weights
=
{}
...
...
@@ -339,7 +438,7 @@ class GraniteMoeModel(nn.Module):
new_weights
[
gate_name
]
=
p
else
:
new_weights
[
n
]
=
p
return
mixtral
.
MixtralModel
.
load_weights
(
self
,
new_weights
.
items
())
return
self
.
_
load_weights
(
new_weights
.
items
())
class
GraniteMoeForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
...
...
vllm/model_executor/models/granitemoeshared.py
View file @
a99b9f7d
...
...
@@ -27,8 +27,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.
import
mixtral
from
.granitemoe
import
GraniteMoeAttention
,
GraniteMoeMoE
from
.granitemoe
import
GraniteMoeAttention
,
GraniteMoeModel
,
GraniteMoeMoE
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
AutoWeightsLoader
,
make_layers
,
maybe_prefix
...
...
@@ -242,7 +241,7 @@ class GraniteMoeSharedModel(nn.Module):
new_weights
[
gate_name
]
=
p
else
:
new_weights
[
n
]
=
p
return
mixtral
.
Mixtral
Model
.
load_weights
(
self
,
new_weights
.
items
())
return
GraniteMoe
Model
.
_
load_weights
(
self
,
new_weights
.
items
())
class
GraniteMoeSharedForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
...
...
vllm/model_executor/models/mixtral.py
View file @
a99b9f7d
...
...
@@ -317,6 +317,15 @@ class MixtralModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Return the expert-weight mapping used when loading checkpoints.

    The mapping is produced by ``FusedMoE.make_expert_params_mapping`` for
    the checkpoint projection names ``w1`` (gate), ``w2`` (down) and
    ``w3`` (up), sized by ``self.config.num_local_experts``.
    """
    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="w1",
        ckpt_down_proj_name="w2",
        ckpt_up_proj_name="w3",
        num_experts=self.config.num_local_experts,
    )
    return expert_mapping
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
...
...
@@ -326,16 +335,9 @@ class MixtralModel(nn.Module):
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"w1"
,
ckpt_down_proj_name
=
"w2"
,
ckpt_up_proj_name
=
"w3"
,
num_experts
=
self
.
config
.
num_local_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
...
...
@@ -486,3 +488,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Return the expert mapping reported by the wrapped ``self.model``."""
    inner_model = self.model
    return inner_model.get_expert_mapping()
vllm/model_executor/models/olmoe.py
View file @
a99b9f7d
...
...
@@ -352,6 +352,7 @@ class OlmoeModel(nn.Module):
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
# Skip non-stacked layers and experts (experts handled below).
...
...
@@ -380,7 +381,7 @@ class OlmoeModel(nn.Module):
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
mapping
in
self
.
get_expert
_mapping
()
:
for
mapping
in
expert_params
_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
continue
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
a99b9f7d
...
...
@@ -413,6 +413,7 @@ class Qwen2MoeModel(nn.Module):
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
# Skip non-stacked layers and experts (experts handled below).
...
...
@@ -442,7 +443,7 @@ class Qwen2MoeModel(nn.Module):
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
mapping
in
self
.
get_expert
_mapping
()
:
for
mapping
in
expert_params
_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
continue
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
a99b9f7d
...
...
@@ -400,11 +400,9 @@ class Qwen3MoeModel(nn.Module):
".v_scale"
,
"_v_scale"
,
".weight_scale"
,
"_weight_scale"
,
".input_scale"
,
"_input_scale"
)
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
self
.
get_expert_mapping
()
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
# Skip non-stacked layers and experts (experts handled below).
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment