Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8020e98c
Unverified
Commit
8020e98c
authored
Jul 11, 2025
by
Jee Jee Li
Committed by
GitHub
Jul 11, 2025
Browse files
[Quantization][1/N] MoE support BNB-Inflight Quantization (#20061)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
762be26a
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
561 additions
and
88 deletions
+561
-88
tests/models/quantization/test_bitsandbytes.py
tests/models/quantization/test_bitsandbytes.py
+39
-6
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+32
-4
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+223
-9
vllm/model_executor/model_loader/bitsandbytes_loader.py
vllm/model_executor/model_loader/bitsandbytes_loader.py
+193
-45
vllm/model_executor/models/olmoe.py
vllm/model_executor/models/olmoe.py
+24
-9
vllm/model_executor/models/phimoe.py
vllm/model_executor/models/phimoe.py
+11
-0
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+26
-9
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+13
-6
No files found.
tests/models/quantization/test_bitsandbytes.py
View file @
8020e98c
...
@@ -14,7 +14,7 @@ from transformers import BitsAndBytesConfig
...
@@ -14,7 +14,7 @@ from transformers import BitsAndBytesConfig
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.quantization.utils
import
is_quant_method_supported
from
...utils
import
compare_two_settings
,
multi_gpu_test
from
...utils
import
compare_two_settings
,
multi_gpu_test
from
..utils
import
check_embeddings_close
from
..utils
import
check_embeddings_close
,
check_logprobs_close
models_4bit_to_test
=
[
models_4bit_to_test
=
[
(
"facebook/opt-125m"
,
"quantize opt model inflight"
),
(
"facebook/opt-125m"
,
"quantize opt model inflight"
),
...
@@ -26,6 +26,10 @@ models_4bit_to_embedding_test = [
...
@@ -26,6 +26,10 @@ models_4bit_to_embedding_test = [
(
"intfloat/e5-mistral-7b-instruct"
,
"quantize embedding model inflight"
),
(
"intfloat/e5-mistral-7b-instruct"
,
"quantize embedding model inflight"
),
]
]
models_4bit_to_moe_test
=
[
(
"allenai/OLMoE-1B-7B-0125-Instruct"
,
"quantize moe model inflight"
),
]
models_pre_qaunt_4bit_to_test
=
[
models_pre_qaunt_4bit_to_test
=
[
(
'PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed'
,
(
'PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed'
,
'read pre-quantized 4-bit FP4 model'
),
'read pre-quantized 4-bit FP4 model'
),
...
@@ -115,6 +119,35 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
...
@@ -115,6 +119,35 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
compare_two_settings
(
model_name
,
common_args
,
pp_args
)
compare_two_settings
(
model_name
,
common_args
,
pp_args
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"bitsandbytes"
),
reason
=
'bitsandbytes is not supported on this GPU type.'
)
@
pytest
.
mark
.
parametrize
(
"model_name, description"
,
models_4bit_to_moe_test
)
def
test_4bit_bnb_moe_model
(
hf_runner
,
vllm_runner
,
example_prompts
,
model_name
,
description
)
->
None
:
hf_model_kwargs
=
dict
(
quantization_config
=
BitsAndBytesConfig
(
load_in_4bit
=
True
,
bnb_4bit_quant_type
=
"nf4"
,
bnb_4bit_use_double_quant
=
True
,
))
with
vllm_runner
(
model_name
,
quantization
=
'bitsandbytes'
,
enforce_eager
=
False
)
as
llm
:
vllm_outputs
=
llm
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
=
32
,
num_logprobs
=
5
)
with
hf_runner
(
model_name
,
model_kwargs
=
hf_model_kwargs
)
as
llm
:
transformers_outputs
=
llm
.
generate_greedy_logprobs_limit
(
example_prompts
,
max_tokens
=
32
,
num_logprobs
=
5
)
check_logprobs_close
(
outputs_0_lst
=
transformers_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"transformers"
,
name_1
=
"vllm"
,
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"bitsandbytes"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"bitsandbytes"
),
reason
=
'bitsandbytes is not supported on this GPU type.'
)
reason
=
'bitsandbytes is not supported on this GPU type.'
)
@
pytest
.
mark
.
parametrize
(
"model_name, description"
,
@
pytest
.
mark
.
parametrize
(
"model_name, description"
,
...
@@ -182,7 +215,8 @@ def validate_generated_texts(hf_runner,
...
@@ -182,7 +215,8 @@ def validate_generated_texts(hf_runner,
model_name
,
model_name
,
pre_quant
=
False
,
pre_quant
=
False
,
hf_model_kwargs
=
None
,
hf_model_kwargs
=
None
,
vllm_tp_size
=
1
):
vllm_tp_size
=
1
,
max_tokens
=
8
):
# NOTE: run vLLM first, as it requires a clean process
# NOTE: run vLLM first, as it requires a clean process
# when using distributed inference
# when using distributed inference
...
@@ -190,7 +224,8 @@ def validate_generated_texts(hf_runner,
...
@@ -190,7 +224,8 @@ def validate_generated_texts(hf_runner,
quantization
=
None
if
pre_quant
else
'bitsandbytes'
,
quantization
=
None
if
pre_quant
else
'bitsandbytes'
,
tensor_parallel_size
=
vllm_tp_size
,
tensor_parallel_size
=
vllm_tp_size
,
enforce_eager
=
False
)
as
llm
:
enforce_eager
=
False
)
as
llm
:
vllm_outputs
=
llm
.
generate_greedy
(
prompts
,
8
)
vllm_outputs
=
llm
.
generate_greedy
(
prompts
,
max_tokens
)
vllm_logs
=
log_generated_texts
(
prompts
,
vllm_outputs
,
"VllmRunner"
)
vllm_logs
=
log_generated_texts
(
prompts
,
vllm_outputs
,
"VllmRunner"
)
# Clean up the GPU memory for the next test
# Clean up the GPU memory for the next test
...
@@ -202,19 +237,17 @@ def validate_generated_texts(hf_runner,
...
@@ -202,19 +237,17 @@ def validate_generated_texts(hf_runner,
# Run with HF runner
# Run with HF runner
with
hf_runner
(
model_name
,
model_kwargs
=
hf_model_kwargs
)
as
llm
:
with
hf_runner
(
model_name
,
model_kwargs
=
hf_model_kwargs
)
as
llm
:
hf_outputs
=
llm
.
generate_greedy
(
prompts
,
8
)
hf_outputs
=
llm
.
generate_greedy
(
prompts
,
max_tokens
)
hf_logs
=
log_generated_texts
(
prompts
,
hf_outputs
,
"HfRunner"
)
hf_logs
=
log_generated_texts
(
prompts
,
hf_outputs
,
"HfRunner"
)
# Clean up the GPU memory for the next test
# Clean up the GPU memory for the next test
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
# Compare the generated strings
# Compare the generated strings
for
hf_log
,
vllm_log
in
zip
(
hf_logs
,
vllm_logs
):
for
hf_log
,
vllm_log
in
zip
(
hf_logs
,
vllm_logs
):
hf_str
=
hf_log
[
"generated_text"
]
hf_str
=
hf_log
[
"generated_text"
]
vllm_str
=
vllm_log
[
"generated_text"
]
vllm_str
=
vllm_log
[
"generated_text"
]
prompt
=
hf_log
[
"prompt"
]
prompt
=
hf_log
[
"prompt"
]
assert
hf_str
==
vllm_str
,
(
f
"Model:
{
model_name
}
"
assert
hf_str
==
vllm_str
,
(
f
"Model:
{
model_name
}
"
f
"Mismatch between HF and vLLM outputs:
\n
"
f
"Mismatch between HF and vLLM outputs:
\n
"
f
"Prompt:
{
prompt
}
\n
"
f
"Prompt:
{
prompt
}
\n
"
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
8020e98c
...
@@ -883,13 +883,20 @@ class FusedMoE(torch.nn.Module):
...
@@ -883,13 +883,20 @@ class FusedMoE(torch.nn.Module):
expert_data
=
expert_data
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
tp_rank
=
tp_rank
)
def
_load_w13
(
self
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
def
_load_w13
(
self
,
shard_id
:
str
,
loaded_weight
:
torch
.
Tensor
,
tp_rank
:
int
):
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
shard_id
:
str
,
loaded_weight
:
torch
.
Tensor
,
tp_rank
:
int
,
load_full
:
bool
=
False
):
# Index the loaded weight for tp sharding.
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size
=
expert_data
.
shape
[
shard_dim
]
//
2
shard_size
=
expert_data
.
shape
[
shard_dim
]
//
2
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
shard_size
*
tp_rank
,
if
not
load_full
:
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
shard_size
*
tp_rank
,
shard_size
)
shard_size
)
# Narrow parameter and load.
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
# w1, gate_proj: Load into first logical weight of w13.
...
@@ -998,6 +1005,27 @@ class FusedMoE(torch.nn.Module):
...
@@ -998,6 +1005,27 @@ class FusedMoE(torch.nn.Module):
param
.
data
.
copy_
(
loaded_weight
)
param
.
data
.
copy_
(
loaded_weight
)
return
True
if
return_success
else
None
return
True
if
return_success
else
None
# Case for BitsAndBytes
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
if
use_bitsandbytes_4bit
:
shard_dim
=
0
expert_data
=
param
.
data
[
expert_id
]
if
shard_id
==
"w2"
:
expert_data
.
copy_
(
loaded_weight
)
elif
shard_id
in
(
"w1"
,
"w3"
):
# BNB inflight quantization has already sharded the weights
full_load
=
True
self
.
_load_w13
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
self
.
tp_rank
,
load_full
=
full_load
,
)
return
True
if
return_success
else
None
# is_transposed: if the dim to shard the weight
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size_per_partition is
# should be whatever dimension intermediate_size_per_partition is
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
8020e98c
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
,
UnquantizedLinearMethod
,
set_weight_attrs
)
set_weight_attrs
)
...
@@ -120,12 +123,15 @@ class BitsAndBytesConfig(QuantizationConfig):
...
@@ -120,12 +123,15 @@ class BitsAndBytesConfig(QuantizationConfig):
llm_int8_skip_modules
=
llm_int8_skip_modules
,
llm_int8_skip_modules
=
llm_int8_skip_modules
,
llm_int8_threshold
=
llm_int8_threshold
)
llm_int8_threshold
=
llm_int8_threshold
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
prefix
:
str
)
->
Optional
[
"LinearMethodBase"
]:
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
Union
[
"LinearMethodBase"
,
"BitsAndBytesMoEMethod"
]]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped_bnb
(
prefix
,
self
.
llm_int8_skip_modules
):
if
is_layer_skipped_bnb
(
prefix
,
self
.
llm_int8_skip_modules
):
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
return
BitsAndBytesLinearMethod
(
self
)
return
BitsAndBytesLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
BitsAndBytesMoEMethod
(
self
)
return
None
return
None
...
@@ -146,6 +152,13 @@ def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
...
@@ -146,6 +152,13 @@ def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
return
substr_check
or
prefix_check
return
substr_check
or
prefix_check
def
calculate_quant_ratio
(
dtype
):
if
dtype
.
is_floating_point
:
return
torch
.
finfo
(
dtype
).
bits
//
torch
.
iinfo
(
torch
.
uint8
).
bits
else
:
return
torch
.
iinfo
(
dtype
).
bits
//
torch
.
iinfo
(
torch
.
uint8
).
bits
class
BitsAndBytesLinearMethod
(
LinearMethodBase
):
class
BitsAndBytesLinearMethod
(
LinearMethodBase
):
"""Linear method for BitsAndBytes.
"""Linear method for BitsAndBytes.
...
@@ -173,12 +186,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
...
@@ -173,12 +186,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
**
extra_weight_attrs
):
**
extra_weight_attrs
):
from
bitsandbytes.nn
import
Int8Params
from
bitsandbytes.nn
import
Int8Params
def
calculate_quant_ratio
(
dtype
):
if
dtype
.
is_floating_point
:
return
torch
.
finfo
(
dtype
).
bits
//
torch
.
iinfo
(
torch
.
uint8
).
bits
else
:
return
torch
.
iinfo
(
dtype
).
bits
//
torch
.
iinfo
(
torch
.
uint8
).
bits
def
create_qweight_for_8bit
():
def
create_qweight_for_8bit
():
qweight
=
Int8Params
(
qweight
=
Int8Params
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
...
@@ -394,3 +401,210 @@ try:
...
@@ -394,3 +401,210 @@ try:
except
AttributeError
as
error
:
except
AttributeError
as
error
:
raise
error
raise
error
class
BitsAndBytesMoEMethod
(
FusedMoEMethodBase
):
"""MoE method for BitsAndBytes.
Args:
quant_config: The BitsAndBytes quantization config.
"""
def
__init__
(
self
,
quant_config
:
BitsAndBytesConfig
):
try
:
import
bitsandbytes
if
bitsandbytes
.
__version__
<
"0.45.3"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.45.3."
)
except
ImportError
as
err
:
raise
ImportError
(
"Please install bitsandbytes>=0.45.3 via "
"`pip install bitsandbytes>=0.45.3` to use "
"bitsandbytes quantizer."
)
from
err
self
.
topk_indices_dtype
=
None
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
if
self
.
quant_config
.
load_in_8bit
:
call_fun
=
self
.
_create_weights_8bit
else
:
call_fun
=
self
.
_create_weights_4bit
call_fun
(
layer
,
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
params_dtype
,
**
extra_weight_attrs
,
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `BitsAndBytesMoEMethod` yet."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
)
if
self
.
quant_config
.
load_in_8bit
:
w13
,
w2
=
self
.
_apply_8bit_dequant
(
layer
)
else
:
w13
,
w2
=
self
.
_apply_4bit_dequnt
(
layer
)
return
fused_experts
(
hidden_states
=
x
,
w1
=
w13
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
)
def
_create_weights_4bit
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
quant_ratio
=
calculate_quant_ratio
(
params_dtype
)
# Fused gate_up_proj (column parallel)
w13_total_size
=
(
hidden_size
*
2
*
intermediate_size_per_partition
)
//
quant_ratio
w13_qweight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
w13_total_size
,
1
,
dtype
=
torch
.
uint8
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_qweight
)
set_weight_attrs
(
w13_qweight
,
extra_weight_attrs
)
set_weight_attrs
(
w13_qweight
,
{
"num_experts"
:
num_experts
,
"input_dim"
:
hidden_size
,
"output_dim"
:
2
*
intermediate_size_per_partition
,
"experts_shape"
:
(
num_experts
,
intermediate_size_per_partition
*
2
,
hidden_size
,
),
"pack_factor"
:
quant_ratio
,
"use_bitsandbytes_4bit"
:
True
,
},
)
# down_proj (row parallel)
w2_total_size
=
(
hidden_size
*
intermediate_size_per_partition
)
//
quant_ratio
w2_qweight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
w2_total_size
,
1
,
dtype
=
torch
.
uint8
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
w2_qweight
,
{
"num_experts"
:
num_experts
,
"input_dim"
:
intermediate_size_per_partition
,
"output_dim"
:
hidden_size
,
"experts_shape"
:
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
),
"pack_factor"
:
quant_ratio
,
"use_bitsandbytes_4bit"
:
True
,
},
)
layer
.
register_parameter
(
"w2_weight"
,
w2_qweight
)
set_weight_attrs
(
w2_qweight
,
extra_weight_attrs
)
def
_create_weights_8bit
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
raise
NotImplementedError
def
_apply_4bit_dequnt
(
self
,
layer
:
torch
.
nn
.
Module
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
bitsandbytes.functional
import
dequantize_4bit
w13
=
dequantize_4bit
(
layer
.
w13_weight
.
reshape
(
-
1
,
1
),
layer
.
w13_weight
.
bnb_quant_state
,
)
w2
=
dequantize_4bit
(
layer
.
w2_weight
.
reshape
(
-
1
,
1
),
layer
.
w2_weight
.
bnb_quant_state
,
)
w13
=
w13
.
reshape
(
layer
.
w13_weight
.
experts_shape
)
w2
=
w2
.
reshape
(
layer
.
w2_weight
.
experts_shape
)
return
w13
,
w2
def
_apply_8bit_dequant
(
self
,
layer
:
torch
.
nn
.
Module
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
vllm/model_executor/model_loader/bitsandbytes_loader.py
View file @
8020e98c
...
@@ -20,6 +20,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -20,6 +20,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
# yapf: enable
# yapf: enable
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -411,9 +412,33 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -411,9 +412,33 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# in case model has a mixture of disk-merged and disk-split
# in case model has a mixture of disk-merged and disk-split
# weights with same last name.
# weights with same last name.
self
.
target_modules
.
append
(
name
)
self
.
target_modules
.
append
(
name
)
elif
(
isinstance
(
module
,
FusedMoE
)
and
hasattr
(
module
.
quant_method
,
"quant_config"
)):
if
not
hasattr
(
model
,
"get_expert_mapping"
):
raise
AttributeError
(
f
"MoE Model
{
type
(
model
).
__name__
}
does not support "
"BitsAndBytes quantization yet. Ensure this model has "
"'get_expert_mapping' method."
)
# TODO: support FusedMoE with prequant and 8bit.
if
self
.
pre_quant
:
raise
ValueError
(
"Prequant BitsAndBytes models with FusedMoE is not "
"supported yet."
)
if
self
.
load_8bit
:
raise
ValueError
(
"BitsAndBytes 8bit quantization with FusedMoE is not "
"supported yet."
)
# Get the corresponding weight name using module name and
# get_expert_mapping.
expert_mapping
=
model
.
get_expert_mapping
()
for
exp
in
expert_mapping
:
weight_name
=
exp
[
1
]
rep_name
=
name
.
replace
(
"experts"
,
""
)
+
weight_name
.
removesuffix
(
"."
)
self
.
target_modules
.
append
(
rep_name
)
assert
(
self
.
target_modules
assert
(
self
.
target_modules
),
"v
llm
currently does not support BNB quantization for"
),
"v
LLM
currently does not support BNB quantization for"
f
"
{
type
(
model
).
__name__
}
"
f
"
{
type
(
model
).
__name__
}
"
def
_classify_module_sharding
(
self
,
model
:
nn
.
Module
):
def
_classify_module_sharding
(
self
,
model
:
nn
.
Module
):
...
@@ -437,6 +462,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -437,6 +462,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# dimension (dim=-1)
# dimension (dim=-1)
elif
isinstance
(
module
,
(
RowParallelLinear
,
)):
elif
isinstance
(
module
,
(
RowParallelLinear
,
)):
self
.
column_sharded_weights_modules
.
append
(
name
)
self
.
column_sharded_weights_modules
.
append
(
name
)
elif
isinstance
(
module
,
FusedMoE
):
expert_mapping
=
model
.
get_expert_mapping
()
for
exp
in
expert_mapping
:
if
exp
[
-
1
]
==
"w2"
:
weight_name
=
exp
[
1
]
rep_name
=
name
.
replace
(
"experts"
,
""
)
+
weight_name
.
removesuffix
(
"."
)
self
.
column_sharded_weights_modules
.
append
(
rep_name
)
def
_verify_model_compatibility
(
self
,
model
:
nn
.
Module
,
def
_verify_model_compatibility
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
model_config
:
ModelConfig
)
->
None
:
...
@@ -490,34 +523,132 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -490,34 +523,132 @@ class BitsAndBytesModelLoader(BaseModelLoader):
self
.
_get_bnb_target_modules
(
model
)
self
.
_get_bnb_target_modules
(
model
)
self
.
_classify_module_sharding
(
model
)
self
.
_classify_module_sharding
(
model
)
def
load_weights
(
self
,
model
:
nn
.
Module
,
def
_dequantize_dq
(
self
,
quant_states
:
Any
):
model_config
:
ModelConfig
)
->
None
:
"""
When BNB employs Double Quantization, we perform the dequantization of
these constants during weight loading rather than at inference time,
thereby avoiding this computational overhead during inference. This
comes at the cost of increased memory usage.
"""
from
bitsandbytes.functional
import
QuantState
,
dequantize_blockwise
self
.
_verify_model_compatibility
(
model
,
model_config
)
def
_dequantize_single_state
(
quant_state
):
self
.
_initialize_loader_state
(
model
,
model_config
)
"""Helper function to dequantize a single QuantState object."""
if
not
(
isinstance
(
quant_state
,
QuantState
)
and
quant_state
.
nested
):
return
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
# Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
"May take a while ..."
)
absmax
=
dequantize_blockwise
(
quant_state
.
absmax
,
qweight_iterator
,
quant_state_dict
=
(
quant_state
.
state2
)
self
.
_get_quantized_weights_iterator
(
absmax
+=
quant_state
.
offset
model_config
.
model
,
model_config
.
revision
,
))
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
loaded_weights
=
model
.
load_weights
(
qweight_iterator
)
# Some models may have weights loading tracker unimplemented.
if
loaded_weights
is
not
None
:
weights_not_loaded
=
weights_to_load
-
loaded_weights
if
weights_not_loaded
:
raise
ValueError
(
"Following weights were not initialized from "
f
"checkpoint:
{
weights_not_loaded
}
"
)
param_dict
=
dict
(
model
.
named_parameters
())
# Ensure float32 dtype
if
absmax
.
dtype
!=
torch
.
float32
:
absmax
=
absmax
.
float
()
quant_state
.
absmax
=
absmax
quant_state
.
nested
=
False
quant_state
.
offset
=
None
quant_state
.
state2
=
None
if
isinstance
(
quant_states
,
dict
):
for
quant_state
in
quant_states
.
values
():
_dequantize_single_state
(
quant_state
)
else
:
_dequantize_single_state
(
quant_states
)
return
quant_states
def
_fuse_moe_quant_states
(
self
,
model
:
nn
.
Module
,
quant_states_dict
:
dict
)
->
dict
:
"""
This function consolidates individual expert quantization states into
fused representations for w13 and w2.
"""
from
bitsandbytes.functional
import
QuantState
if
not
hasattr
(
model
,
"get_expert_mapping"
):
return
dict
()
expert_mapping
=
model
.
get_expert_mapping
()
expert_qs_dict
=
{}
for
name
,
module
in
model
.
named_modules
():
if
not
isinstance
(
module
,
FusedMoE
):
continue
w1_states_lst
=
[]
w2_states_lst
=
[]
w3_states_lst
=
[]
for
exp
in
expert_mapping
:
shard_id
=
exp
[
-
1
]
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
f
"got
{
shard_id
}
."
)
layer_prefix
=
name
.
split
(
"experts"
)[
0
]
weight_qual_name
=
layer_prefix
+
exp
[
1
]
+
"weight"
quant_state
=
self
.
_dequantize_dq
(
quant_states_dict
[
weight_qual_name
])
if
shard_id
==
"w1"
:
w1_states_lst
.
append
(
quant_state
)
elif
shard_id
==
"w2"
:
w2_states_lst
.
append
(
quant_state
)
else
:
w3_states_lst
.
append
(
quant_state
)
del
quant_states_dict
[
weight_qual_name
]
assert
(
len
(
w1_states_lst
)
==
len
(
w2_states_lst
)
==
len
(
w3_states_lst
))
w13_absmax_lst
=
[]
w2_absmax_lst
=
[]
w13_total_dim0
=
0
w2_total_dim0
=
0
for
w1_qs
,
w2_qs
,
w3_qs
in
zip
(
w1_states_lst
,
w2_states_lst
,
w3_states_lst
):
assert
w1_qs
.
shape
==
w3_qs
.
shape
assert
w1_qs
.
blocksize
==
w2_qs
.
blocksize
==
w3_qs
.
blocksize
assert
w1_qs
.
dtype
==
w2_qs
.
dtype
==
w3_qs
.
dtype
# w1 and w3 are interleaved in storage
w13_absmax_lst
.
append
(
w1_qs
.
absmax
)
w13_absmax_lst
.
append
(
w3_qs
.
absmax
)
w2_absmax_lst
.
append
(
w2_qs
.
absmax
)
w13_total_dim0
+=
w1_qs
.
shape
[
0
]
+
w3_qs
.
shape
[
0
]
w2_total_dim0
+=
w2_qs
.
shape
[
0
]
w13_absmax
=
torch
.
cat
(
w13_absmax_lst
)
w2_absmax
=
torch
.
cat
(
w2_absmax_lst
)
# Create fused quantization state for w13.
w13_qs
=
QuantState
(
absmax
=
w13_absmax
,
shape
=
(
w13_total_dim0
,
w1_states_lst
[
0
].
shape
[
1
]),
code
=
w1_states_lst
[
0
].
code
,
blocksize
=
w1_states_lst
[
0
].
blocksize
,
quant_type
=
"nf4"
,
dtype
=
w1_states_lst
[
0
].
dtype
,
)
# Create fused quantization state for w2.
w2_qs
=
QuantState
(
absmax
=
w2_absmax
,
shape
=
(
w2_total_dim0
,
w2_states_lst
[
0
].
shape
[
1
]),
code
=
w2_states_lst
[
0
].
code
,
blocksize
=
w2_states_lst
[
0
].
blocksize
,
quant_type
=
"nf4"
,
dtype
=
w2_states_lst
[
0
].
dtype
,
)
# The weight suffixes .w13_weight and .w2_weight are consistent
# with the param in BitsAndBytesMoEMethod.
w13_weight_name
=
name
+
".w13_weight"
w2_weight_name
=
name
+
".w2_weight"
expert_qs_dict
[
w13_weight_name
]
=
w13_qs
expert_qs_dict
[
w2_weight_name
]
=
w2_qs
return
expert_qs_dict
def
_stack_quantization_states
(
self
,
model
:
nn
.
Module
,
quant_state_dict
:
dict
)
->
dict
[
str
,
dict
[
int
,
Any
]]:
stacked_quant_state_dict
:
dict
[
str
,
dict
[
int
,
Any
]]
=
{}
stacked_quant_state_dict
:
dict
[
str
,
dict
[
int
,
Any
]]
=
{}
# TODO: Change this lazy import to normal import
# TODO: Change this lazy import to normal import
# after the checks are updated to run on a new version
# after the checks are updated to run on a new version
from
vllm.model_executor.models.utils
import
is_pp_missing_parameter
from
vllm.model_executor.models.utils
import
is_pp_missing_parameter
param_dict
=
dict
(
model
.
named_parameters
())
for
quant_param_name
in
quant_state_dict
:
for
quant_param_name
in
quant_state_dict
:
if
is_pp_missing_parameter
(
quant_param_name
,
model
):
if
is_pp_missing_parameter
(
quant_param_name
,
model
):
continue
continue
...
@@ -558,14 +689,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -558,14 +689,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):
stacked_quant_state_dict
[
quant_param_name
][
shard_index
]
=
(
stacked_quant_state_dict
[
quant_param_name
][
shard_index
]
=
(
quant_state_dict
[
non_stacked_param_name
])
quant_state_dict
[
non_stacked_param_name
])
return
stacked_quant_state_dict
def
_bind_quant_states_to_params
(
self
,
model
:
nn
.
Module
,
stacked_quant_state_dict
:
dict
)
->
None
:
# save quant_states and offsets as the attributes of the parameters
# save quant_states and offsets as the attributes of the parameters
param_dict
=
dict
(
model
.
named_parameters
())
for
param_name
,
param
in
param_dict
.
items
():
for
param_name
,
param
in
param_dict
.
items
():
if
param_name
in
stacked_quant_state_dict
:
if
param_name
in
stacked_quant_state_dict
:
quant_states
=
stacked_quant_state_dict
[
param_name
]
quant_states
=
stacked_quant_state_dict
[
param_name
]
# Dequantize double quantized values during weight loading.
# Dequantize double quantized values during weight loading.
dequantize_dq
(
quant_states
)
self
.
_
dequantize_dq
(
quant_states
)
set_weight_attrs
(
param
,
{
"bnb_quant_state"
:
quant_states
})
set_weight_attrs
(
param
,
{
"bnb_quant_state"
:
quant_states
})
if
not
isinstance
(
quant_states
,
dict
):
continue
pack_ratio
=
getattr
(
param
,
"pack_factor"
,
-
1
)
pack_ratio
=
getattr
(
param
,
"pack_factor"
,
-
1
)
if
pack_ratio
==
-
1
:
if
pack_ratio
==
-
1
:
...
@@ -585,29 +722,40 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -585,29 +722,40 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if
self
.
load_8bit
:
if
self
.
load_8bit
:
set_weight_attrs
(
set_weight_attrs
(
param
,
{
"matmul_state"
:
[
None
]
*
len
(
quant_states
)})
param
,
{
"matmul_state"
:
[
None
]
*
len
(
quant_states
)})
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
self
.
_verify_model_compatibility
(
model
,
model_config
)
self
.
_initialize_loader_state
(
model
,
model_config
)
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
"May take a while ..."
)
qweight_iterator
,
quant_state_dict
=
(
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
))
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
loaded_weights
=
model
.
load_weights
(
qweight_iterator
)
# Some models may have weights loading tracker unimplemented.
if
loaded_weights
is
not
None
:
weights_not_loaded
=
weights_to_load
-
loaded_weights
if
weights_not_loaded
:
raise
ValueError
(
"Following weights were not initialized from "
f
"checkpoint:
{
weights_not_loaded
}
"
)
expert_quant_state_dict
=
self
.
_fuse_moe_quant_states
(
model
,
quant_state_dict
)
stacked_quant_state_dict
=
self
.
_stack_quantization_states
(
model
,
quant_state_dict
)
stacked_quant_state_dict
=
{
**
expert_quant_state_dict
,
**
stacked_quant_state_dict
}
self
.
_bind_quant_states_to_params
(
model
,
stacked_quant_state_dict
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
dequantize_dq
(
quant_states
:
dict
)
->
None
:
"""
When BNB employs Double Quantization, we perform the dequantization of
these constants during weight loading rather than at inference time,
thereby avoiding this computational overhead during inference. This comes
at the cost of increased memory usage.
"""
from
bitsandbytes.functional
import
QuantState
,
dequantize_blockwise
for
_
,
quant_state
in
quant_states
.
items
():
# Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
if
isinstance
(
quant_state
,
QuantState
)
and
quant_state
.
nested
:
absmax
=
dequantize_blockwise
(
quant_state
.
absmax
,
quant_state
.
state2
)
absmax
+=
quant_state
.
offset
if
absmax
.
dtype
!=
torch
.
float32
:
absmax
=
absmax
.
float
()
quant_state
.
absmax
=
absmax
quant_state
.
nested
=
False
quant_state
.
offset
=
None
quant_state
.
state2
=
None
vllm/model_executor/models/olmoe.py
View file @
8020e98c
...
@@ -330,6 +330,15 @@ class OlmoeModel(nn.Module):
...
@@ -330,6 +330,15 @@ class OlmoeModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
@@ -341,14 +350,6 @@ class OlmoeModel(nn.Module):
...
@@ -341,14 +350,6 @@ class OlmoeModel(nn.Module):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
...
@@ -379,7 +380,7 @@ class OlmoeModel(nn.Module):
...
@@ -379,7 +380,7 @@ class OlmoeModel(nn.Module):
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
for
mapping
in
expert_params
_mapping
:
for
mapping
in
self
.
get_expert
_mapping
()
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
...
@@ -425,6 +426,17 @@ class OlmoeModel(nn.Module):
...
@@ -425,6 +426,17 @@ class OlmoeModel(nn.Module):
class
OlmoeForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
OlmoeForCausalLM
(
nn
.
Module
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
...
@@ -466,3 +478,6 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
...
@@ -466,3 +478,6 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
vllm/model_executor/models/phimoe.py
View file @
8020e98c
...
@@ -516,6 +516,14 @@ class PhiMoEModel(nn.Module):
...
@@ -516,6 +516,14 @@ class PhiMoEModel(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"w1"
,
ckpt_down_proj_name
=
"w2"
,
ckpt_up_proj_name
=
"w3"
,
num_experts
=
self
.
config
.
num_local_experts
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
@@ -672,3 +680,6 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -672,3 +680,6 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
vllm/model_executor/models/qwen2_moe.py
View file @
8020e98c
...
@@ -391,6 +391,15 @@ class Qwen2MoeModel(nn.Module):
...
@@ -391,6 +391,15 @@ class Qwen2MoeModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
@@ -402,14 +411,6 @@ class Qwen2MoeModel(nn.Module):
...
@@ -402,14 +411,6 @@ class Qwen2MoeModel(nn.Module):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
...
@@ -441,11 +442,13 @@ class Qwen2MoeModel(nn.Module):
...
@@ -441,11 +442,13 @@ class Qwen2MoeModel(nn.Module):
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
for
mapping
in
expert_params
_mapping
:
for
mapping
in
self
.
get_expert
_mapping
()
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
if
"layers.13.mlp.experts.w2_weight"
in
name
:
pass
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
...
@@ -493,6 +496,17 @@ class Qwen2MoeModel(nn.Module):
...
@@ -493,6 +496,17 @@ class Qwen2MoeModel(nn.Module):
class
Qwen2MoeForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
Qwen2MoeForCausalLM
(
nn
.
Module
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
fall_back_to_pt_during_load
=
False
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
...
@@ -538,3 +552,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
...
@@ -538,3 +552,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
vllm/model_executor/models/qwen3_moe.py
View file @
8020e98c
...
@@ -375,6 +375,15 @@ class Qwen3MoeModel(nn.Module):
...
@@ -375,6 +375,15 @@ class Qwen3MoeModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
@@ -393,12 +402,7 @@ class Qwen3MoeModel(nn.Module):
...
@@ -393,12 +402,7 @@ class Qwen3MoeModel(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
expert_params_mapping
=
self
.
get_expert_mapping
()
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
...
@@ -539,3 +543,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
...
@@ -539,3 +543,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment