Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fb6af8bc
Unverified
Commit
fb6af8bc
authored
Jul 13, 2024
by
Robert Shaw
Committed by
GitHub
Jul 13, 2024
Browse files
[ Misc ] Apply MoE Refactor to Deepseekv2 To Support Fp8 (#6417)
parent
eeceadae
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
223 additions
and
137 deletions
+223
-137
.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
...ldkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
+11
-0
.buildkite/lm-eval-harness/configs/models-large.txt
.buildkite/lm-eval-harness/configs/models-large.txt
+1
-0
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
+1
-1
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+26
-10
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+78
-15
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+8
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+69
-73
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+7
-25
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+22
-11
No files found.
.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
0 → 100644
View file @
fb6af8bc
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m deepseek-ai/DeepSeek-V2-Lite-Chat -b "auto" -l 1000 -f 5 -t 2
model_name
:
"
deepseek-ai/DeepSeek-V2-Lite-Chat"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.671
-
name
:
"
exact_match,flexible-extract"
value
:
0.664
limit
:
1000
num_fewshot
:
5
.buildkite/lm-eval-harness/configs/models-large.txt
View file @
fb6af8bc
Meta-Llama-3-70B-Instruct.yaml
Mixtral-8x7B-Instruct-v0.1.yaml
Qwen2-57B-A14-Instruct.yaml
DeepSeek-V2-Lite-Chat.yaml
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
View file @
fb6af8bc
...
...
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done
lm_eval
--model
vllm
\
--model_args
pretrained
=
$MODEL
,tensor_parallel_size
=
$TP_SIZE
,add_bos_token
=
true
,distributed_executor_backend
=
"ray"
\
--model_args
pretrained
=
$MODEL
,tensor_parallel_size
=
$TP_SIZE
,add_bos_token
=
true
,distributed_executor_backend
=
"ray"
,trust_remote_code
=
true
\
--tasks
gsm8k
--num_fewshot
$FEWSHOT
--limit
$LIMIT
\
--batch_size
$BATCH_SIZE
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
fb6af8bc
...
...
@@ -394,14 +394,16 @@ def fused_topk(
# This is used by the Deepseek-V2 model
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
,
):
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
):
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
scores
=
torch
.
softmax
(
gating_output
,
dim
=-
1
)
num_token
=
scores
.
shape
[
0
]
group_scores
=
scores
.
view
(
num_token
,
num_expert_group
,
...
...
@@ -557,6 +559,9 @@ def fused_moe(
renormalize
:
bool
,
inplace
:
bool
=
False
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
use_fp8
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -579,6 +584,10 @@ def fused_moe(
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
...
...
@@ -592,8 +601,15 @@ def fused_moe(
# Check constraints.
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
if
use_grouped_topk
:
assert
num_expert_group
is
not
None
and
topk_group
is
not
None
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
,
num_expert_group
,
topk_group
)
else
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
return
fused_experts
(
hidden_states
,
w1
,
w2
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
fb6af8bc
from
abc
import
abstractmethod
from
typing
import
Optional
from
typing
import
List
,
Optional
,
Tuple
import
torch
...
...
@@ -29,7 +29,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
)
->
torch
.
Tensor
:
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
raise
NotImplementedError
...
...
@@ -63,7 +66,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
)
->
torch
.
Tensor
:
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
fused_moe
(
x
,
layer
.
w13_weight
,
...
...
@@ -71,7 +77,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
router_logits
,
top_k
,
renormalize
=
renormalize
,
inplace
=
True
)
inplace
=
True
,
use_grouped_topk
=
use_grouped_topk
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
class
FusedMoE
(
torch
.
nn
.
Module
):
...
...
@@ -104,6 +113,9 @@ class FusedMoE(torch.nn.Module):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
False
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
):
...
...
@@ -119,6 +131,11 @@ class FusedMoE(torch.nn.Module):
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
reduce_results
=
reduce_results
self
.
renormalize
=
renormalize
self
.
use_grouped_topk
=
use_grouped_topk
if
self
.
use_grouped_topk
:
assert
num_expert_group
is
not
None
and
topk_group
is
not
None
self
.
num_expert_group
=
num_expert_group
self
.
topk_group
=
topk_group
if
quant_config
is
None
:
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
(
...
...
@@ -140,9 +157,8 @@ class FusedMoE(torch.nn.Module):
shard_id
:
int
,
expert_id
:
int
):
param_data
=
param
.
data
# FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
# Follow up PR to enable fp8 for other MoE models.
if
"input_scale"
in
weight_name
or
"w2.weight_scale"
in
weight_name
:
# Input scales can be loaded directly and should be equal.
if
"input_scale"
in
weight_name
:
if
param_data
[
expert_id
]
!=
1
and
(
param_data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
...
...
@@ -150,14 +166,21 @@ class FusedMoE(torch.nn.Module):
f
"must be equal. But got
{
param_data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
param_data
[
expert_id
]
=
loaded_weight
# FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
# Follow up PR to enable fp8 for other MoE models.
# Weight scales
elif
"weight_scale"
in
weight_name
:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
assert
"w1"
in
weight_name
or
"w3"
in
weight_name
shard_id
=
0
if
"w1"
in
weight_name
else
1
param_data
[
expert_id
][
shard_id
]
=
loaded_weight
# If we are in merged column case (gate_up_proj)
# shard_id 0 == gate_proj / w1
# shard_id 2 == up_proj / w3
if
shard_id
==
0
or
shard_id
==
2
:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx
=
0
if
shard_id
==
0
else
1
param_data
[
expert_id
][
idx
]
=
loaded_weight
# If we are in the row parallel case (down_proj)
# shard_id 1 == down_proj / w2
else
:
param_data
[
expert_id
]
=
loaded_weight
# Weights
else
:
tp_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
intermediate_size_per_partition
...
...
@@ -188,10 +211,50 @@ class FusedMoE(torch.nn.Module):
x
=
hidden_states
,
router_logits
=
router_logits
,
top_k
=
self
.
top_k
,
renormalize
=
self
.
renormalize
)
renormalize
=
self
.
renormalize
,
use_grouped_topk
=
self
.
use_grouped_topk
,
num_expert_group
=
self
.
num_expert_group
,
topk_group
=
self
.
topk_group
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
@
classmethod
def
make_expert_params_mapping
(
cls
,
ckpt_gate_proj_name
:
str
,
ckpt_down_proj_name
:
str
,
ckpt_up_proj_name
:
str
,
num_experts
:
int
)
->
List
[
Tuple
[
str
,
str
,
int
,
int
]]:
gate_up
=
[
ckpt_gate_proj_name
,
ckpt_up_proj_name
]
gate_down_up
=
[
ckpt_gate_proj_name
,
ckpt_down_proj_name
,
ckpt_up_proj_name
]
return
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_scale"
if
weight_name
in
gate_up
else
"experts.w2_scale"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight_scale"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
enumerate
(
gate_down_up
)
]
+
[
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_weight"
if
weight_name
in
gate_up
else
"experts.w2_weight"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
enumerate
(
gate_down_up
)
]
+
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
(
"experts.a13_scale"
if
weight_name
in
gate_up
else
"experts.a2_scale"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.input_scale"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
enumerate
(
gate_down_up
)
]
vllm/model_executor/layers/quantization/fp8.py
View file @
fb6af8bc
...
...
@@ -377,7 +377,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
)
->
torch
.
Tensor
:
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
fused_moe
(
x
,
layer
.
w13_weight
,
...
...
@@ -390,7 +393,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w1_scale
=
layer
.
w13_scale
,
w2_scale
=
layer
.
w2_scale
,
a1_scale
=
layer
.
a13_scale
,
a2_scale
=
layer
.
a2_scale
)
a2_scale
=
layer
.
a2_scale
,
use_grouped_topk
=
use_grouped_topk
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
class
Fp8KVCacheMethod
(
QuantizeMethodBase
):
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
fb6af8bc
...
...
@@ -29,11 +29,10 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
f
used
_experts
,
grouped_topk
from
vllm.model_executor.layers.fused_moe
import
F
used
MoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
...
...
@@ -91,32 +90,34 @@ class DeepseekV2MoE(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
n_routed_experts
=
config
.
n_routed_experts
self
.
top_k
=
config
.
num_experts_per_tok
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
if
self
.
tp_size
>
self
.
n_routed_experts
:
self
.
n_shared_experts
=
config
.
n_shared_experts
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
if
self
.
tp_size
>
config
.
n_routed_experts
:
raise
ValueError
(
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"the number of experts
{
self
.
n_routed_experts
}
."
)
self
.
experts
=
nn
.
ModuleList
([
DeepseekV2MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
False
)
for
idx
in
range
(
self
.
n_routed_experts
)
])
self
.
pack_params
()
f
"the number of experts
{
config
.
n_routed_experts
}
."
)
if
config
.
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
config
.
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
n_routed_experts
,
config
.
n_routed_experts
,
bias
=
False
,
quant_config
=
None
)
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
...
...
@@ -128,50 +129,21 @@ class DeepseekV2MoE(nn.Module):
reduce_results
=
False
,
)
def
pack_params
(
self
):
w1
=
[]
w2
=
[]
for
expert
in
self
.
experts
:
w1
.
append
(
expert
.
gate_up_proj
.
weight
)
w2
.
append
(
expert
.
down_proj
.
weight
)
self
.
w1
=
torch
.
_utils
.
_flatten_dense_tensors
(
w1
)
w1s
=
torch
.
_utils
.
_unflatten_dense_tensors
(
self
.
w1
,
w1
)
for
data
,
param
in
zip
(
w1s
,
w1
):
param
.
data
=
data
self
.
w1
=
self
.
w1
.
view
(
len
(
w1
),
*
w1s
[
0
].
shape
)
self
.
w2
=
torch
.
_utils
.
_flatten_dense_tensors
(
w2
)
w2s
=
torch
.
_utils
.
_unflatten_dense_tensors
(
self
.
w2
,
w2
)
for
data
,
param
in
zip
(
w2s
,
w2
):
param
.
data
=
data
self
.
w2
=
self
.
w2
.
view
(
len
(
w2
),
*
w2s
[
0
].
shape
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
config
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
,
router_logits
,
self
.
top_k
,
renormalize
=
self
.
config
.
norm_topk_prob
,
num_expert_group
=
self
.
config
.
n_group
,
topk_group
=
self
.
config
.
topk_group
)
final_hidden_states
=
fused_experts
(
hidden_states
,
self
.
w1
,
self
.
w2
,
topk_weights
,
topk_ids
,
inplace
=
True
)
*
self
.
routed_scaling_factor
if
self
.
config
.
n_shared_experts
is
not
None
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
...
...
@@ -504,34 +476,58 @@ class DeepseekV2ForCausalLM(nn.Module):
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
n_routed_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
# Skip non-stacked layers and experts (experts handled below).
if
weight_name
not
in
name
:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if
((
"mlp.experts."
in
name
)
and
name
not
in
params_dict
):
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/mixtral.py
View file @
fb6af8bc
...
...
@@ -372,31 +372,13 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
expert_params_mapping
=
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_scale"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"experts.w2_scale"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight_scale"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
shard_id
,
weight_name
in
enumerate
([
"w1"
,
"w2"
,
"w3"
])
]
+
[
# These are the weights for the experts
# (param_name, weight_name, expert_id)
(
"experts.w13_weight"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"experts.w2_weight"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
shard_id
,
weight_name
in
enumerate
([
"w1"
,
"w2"
,
"w3"
])
]
+
[
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
(
"experts.a13_scale"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"experts.a2_scale"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.input_scale"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
shard_id
,
weight_name
in
enumerate
([
"w1"
,
"w2"
,
"w3"
])
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"w1"
,
ckpt_down_proj_name
=
"w2"
,
ckpt_up_proj_name
=
"w3"
,
num_experts
=
self
.
config
.
num_local_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
fb6af8bc
...
...
@@ -50,6 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
print_warning_once
class
Qwen2MoeMLP
(
nn
.
Module
):
...
...
@@ -406,15 +407,13 @@ class Qwen2MoeForCausalLM(nn.Module):
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
expert_params_mapping
=
[
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_weight"
if
weight_name
in
[
"gate_proj"
,
"up_proj"
]
else
"experts.w2_weight"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
self
.
config
.
num_experts
)
for
shard_id
,
weight_name
in
enumerate
([
"gate_proj"
,
"down_proj"
,
"up_proj"
])
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
...
...
@@ -461,8 +460,20 @@ class Qwen2MoeForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
name
not
in
params_dict
:
continue
# Remapping the name of FP8 kv-scale.
if
name
.
endswith
(
"kv_scale"
):
remapped_kv_scale_name
=
name
.
replace
(
".kv_scale"
,
".attn.kv_scale"
)
if
remapped_kv_scale_name
not
in
params_dict
:
print_warning_once
(
"Found kv scale in the checkpoint "
f
"(e.g.
{
name
}
), but not found the expected "
f
"name in the model "
f
"(e.g.
{
remapped_kv_scale_name
}
). "
"kv-scale is not loaded."
)
continue
else
:
name
=
remapped_kv_scale_name
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment