Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5c3f2628
Unverified
Commit
5c3f2628
authored
Jul 25, 2025
by
Jee Jee Li
Committed by
GitHub
Jul 25, 2025
Browse files
[Quantization] Enable BNB support for more MoE models (#21370)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
7311f744
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
93 additions
and
80 deletions
+93
-80
vllm/model_executor/models/dots1.py
vllm/model_executor/models/dots1.py
+78
-70
vllm/model_executor/models/glm4_moe.py
vllm/model_executor/models/glm4_moe.py
+15
-10
No files found.
vllm/model_executor/models/dots1.py
View file @
5c3f2628
...
@@ -54,8 +54,8 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -54,8 +54,8 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -327,6 +327,7 @@ class Dots1DecoderLayer(nn.Module):
...
@@ -327,6 +327,7 @@ class Dots1DecoderLayer(nn.Module):
return
hidden_states
,
residual
return
hidden_states
,
residual
@
support_torch_compile
class
Dots1Model
(
nn
.
Module
):
class
Dots1Model
(
nn
.
Module
):
fall_back_to_pt_during_load
=
False
fall_back_to_pt_during_load
=
False
...
@@ -404,68 +405,12 @@ class Dots1Model(nn.Module):
...
@@ -404,68 +405,12 @@ class Dots1Model(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
@
support_torch_compile
return
FusedMoE
.
make_expert_params_mapping
(
class
Dots1ForCausalLM
(
nn
.
Module
,
SupportsPP
):
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
ckpt_up_proj_name
=
"up_proj"
,
super
().
__init__
()
num_experts
=
self
.
config
.
n_routed_experts
)
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
Dots1Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
,
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
...
@@ -477,14 +422,9 @@ class Dots1ForCausalLM(nn.Module, SupportsPP):
...
@@ -477,14 +422,9 @@ class Dots1ForCausalLM(nn.Module, SupportsPP):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
n_routed_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
@@ -534,3 +474,71 @@ class Dots1ForCausalLM(nn.Module, SupportsPP):
...
@@ -534,3 +474,71 @@ class Dots1ForCausalLM(nn.Module, SupportsPP):
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
loaded_params
.
add
(
name
)
return
loaded_params
return
loaded_params
class
Dots1ForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsLoRA
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
Dots1Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
,
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
vllm/model_executor/models/glm4_moe.py
View file @
5c3f2628
...
@@ -53,7 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -53,7 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -461,6 +461,15 @@ class Glm4MoeModel(nn.Module):
...
@@ -461,6 +461,15 @@ class Glm4MoeModel(nn.Module):
device
=
device
),
device
=
device
),
})
})
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
n_routed_experts
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
@@ -472,16 +481,9 @@ class Glm4MoeModel(nn.Module):
...
@@ -472,16 +481,9 @@ class Glm4MoeModel(nn.Module):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
n_routed_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
spec_layer
=
get_spec_layer_idx_from_weight_name
(
self
.
config
,
name
)
spec_layer
=
get_spec_layer_idx_from_weight_name
(
self
.
config
,
name
)
if
spec_layer
is
not
None
:
if
spec_layer
is
not
None
:
...
@@ -570,7 +572,7 @@ class Glm4MoeModel(nn.Module):
...
@@ -570,7 +572,7 @@ class Glm4MoeModel(nn.Module):
return
loaded_params
return
loaded_params
class
Glm4MoeForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
Glm4MoeForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsLoRA
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
@@ -677,6 +679,9 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP):
...
@@ -677,6 +679,9 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP):
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
def
get_spec_layer_idx_from_weight_name
(
config
:
PretrainedConfig
,
def
get_spec_layer_idx_from_weight_name
(
config
:
PretrainedConfig
,
weight_name
:
str
)
->
Optional
[
int
]:
weight_name
:
str
)
->
Optional
[
int
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment