Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d3bdfd3a
Unverified
Commit
d3bdfd3a
authored
Aug 13, 2024
by
Dipika Sikka
Committed by
GitHub
Aug 13, 2024
Browse files
[Misc] Update Fused MoE weight loading (#7334)
parent
fb377d7e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
264 additions
and
201 deletions
+264
-201
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+180
-136
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+80
-61
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+1
-1
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+1
-1
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+1
-1
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+1
-1
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
d3bdfd3a
...
...
@@ -24,15 +24,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise
NotImplementedError
@
abstractmethod
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
)
->
torch
.
Tensor
:
raise
NotImplementedError
...
...
@@ -61,66 +55,78 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
top_k
,
renormalize
,
use_grouped_topk
,
num_expert_group
,
topk_group
)
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
num_expert_group
:
Optional
[
int
],
topk_group
:
Optional
[
int
],
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_moe
return
fused_moe
(
x
,
w1
,
w2
,
router_logits
,
top_k
,
renormalize
=
renormalize
,
inplace
=
True
,
use_grouped_topk
=
use_grouped_topk
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
=
x
,
layer
=
layer
,
router_logits
=
router_logits
,
top_k
=
top_k
,
renormalize
=
renormalize
,
use_grouped_topk
=
use_grouped_topk
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
)
def
forward_cuda
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
)
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
)
def
forward_cpu
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
(
"The CPU backend currently does not support MoE."
)
def
forward_tpu
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
num_expert_group
:
Optional
[
int
],
topk_group
:
Optional
[
int
],
)
->
torch
.
Tensor
:
def
forward_tpu
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.moe_pallas
import
fused_moe
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
return
fused_moe
(
x
,
w1
,
w2
,
router_logits
,
top_k
,
renormalize
)
return
fused_moe
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk
=
top_k
,
gating_output
=
router_logits
,
renormalize
=
renormalize
)
class
FusedMoE
(
torch
.
nn
.
Module
):
...
...
@@ -195,52 +201,83 @@ class FusedMoE(torch.nn.Module):
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
int
,
expert_id
:
int
):
param_data
=
param
.
data
# Input scales can be loaded directly and should be equal.
if
"input_scale"
in
weight_name
:
if
param_data
[
expert_id
]
!=
1
and
(
param_data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param_data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
param_data
[
expert_id
]
=
loaded_weight
# Weight scales
elif
"weight_scale"
in
weight_name
:
# If we are in merged column case (gate_up_proj)
# shard_id 0 == gate_proj / w1
# shard_id 2 == up_proj / w3
if
shard_id
==
0
or
shard_id
==
2
:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx
=
0
if
shard_id
==
0
else
1
param_data
[
expert_id
][
idx
]
=
loaded_weight
# If we are in the row parallel case (down_proj)
# shard_id 1 == down_proj / w2
else
:
param_data
[
expert_id
]
=
loaded_weight
# Weights
shard_id
:
str
,
expert_id
:
int
)
->
None
:
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
f
"got
{
shard_id
}
."
)
# Special case for fp8 scales.
if
getattr
(
param
,
"is_fp8_scale"
,
False
):
self
.
_load_fp8_scale
(
param
.
data
,
loaded_weight
,
weight_name
,
shard_id
,
expert_id
)
return
expert_data
=
param
.
data
[
expert_id
]
tp_rank
=
get_tensor_model_parallel_rank
()
# If transposed, weight is saved as [input_dim, output_dim]
# Otherwise, weight is saved as [output_dim, input_dim]
# Default is not transposed/input dim is dim 1
input_dim
=
getattr
(
param
,
"input_dim"
,
1
)
output_dim
=
getattr
(
param
,
"output_dim"
,
0
)
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
if
shard_id
==
"w2"
:
shard_dim
=
input_dim
shard_size
=
expert_data
.
shape
[
shard_dim
]
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
elif
shard_id
in
(
"w1"
,
"w3"
):
shard_dim
=
output_dim
shard_size
=
expert_data
.
shape
[
output_dim
]
//
2
offset
=
shard_size
*
tp_rank
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
offset
,
shard_size
)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if
shard_id
==
"w1"
:
expert_data
=
expert_data
.
narrow
(
shard_dim
,
0
,
shard_size
)
expert_data
.
copy_
(
loaded_weight
)
# w3, up_proj: Load into second logical weight of w13.
elif
shard_id
==
"w3"
:
expert_data
=
expert_data
.
narrow
(
shard_dim
,
shard_size
,
shard_size
)
expert_data
.
copy_
(
loaded_weight
)
# w2, down_proj: Load into only logical weight of w2.
elif
shard_id
==
"w2"
:
expert_data
.
copy_
(
loaded_weight
)
else
:
tp_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
intermediate_size_per_partition
shard
=
slice
(
tp_rank
*
shard_size
,
(
tp_rank
+
1
)
*
shard_size
)
# w1, gate_proj case: Load into first shard of w13.
if
shard_id
==
0
:
param_data
[
expert_id
,
0
:
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
# w3, up_proj case: Load into second shard of w13.
elif
shard_id
==
2
:
param_data
[
expert_id
,
shard_size
:
2
*
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
# w2, down_proj case: Load into only shard of w2.
elif
shard_id
==
1
:
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
else
:
raise
ValueError
(
f
"Shard id must be in [0,1,2] but got
{
shard_id
}
"
)
raise
ValueError
(
f
"Expected shard_id w1,w2 or w3 but got
{
shard_id
}
"
)
@
staticmethod
def
select_experts
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
use_grouped_topk
:
bool
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
):
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
grouped_topk
)
# DeekSeekv2 uses grouped_top_k
if
use_grouped_topk
:
assert
topk_group
is
not
None
assert
num_expert_group
is
not
None
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
topk
=
top_k
,
renormalize
=
renormalize
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
else
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
topk
=
top_k
,
renormalize
=
renormalize
)
return
topk_weights
,
topk_ids
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
...
...
@@ -248,14 +285,14 @@ class FusedMoE(torch.nn.Module):
# Matrix multiply.
final_hidden_states
=
self
.
quant_method
.
apply
(
self
,
layer
=
self
,
x
=
hidden_states
,
router_logits
=
router_logits
,
top_k
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
use_grouped_topk
=
self
.
use_grouped_topk
,
num_expert
_group
=
self
.
num_expert
_group
,
topk
_group
=
self
.
topk
_group
)
topk
_group
=
self
.
topk
_group
,
num_expert
_group
=
self
.
num_expert
_group
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
...
...
@@ -267,35 +304,42 @@ class FusedMoE(torch.nn.Module):
def
make_expert_params_mapping
(
cls
,
ckpt_gate_proj_name
:
str
,
ckpt_down_proj_name
:
str
,
ckpt_up_proj_name
:
str
,
num_experts
:
int
)
->
List
[
Tuple
[
str
,
str
,
int
,
int
]]:
gate_up
=
[
ckpt_gate_proj_name
,
ckpt_up_proj_name
]
gate_down_up
=
[
ckpt_gate_proj_name
,
ckpt_down_proj_name
,
ckpt_up_proj_name
]
num_experts
:
int
)
->
List
[
Tuple
[
str
,
str
,
int
,
str
]]:
return
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_scale"
if
weight_name
in
gate_up
else
"experts.w2_scale"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight_scale"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
enumerate
(
gate_down_up
)
]
+
[
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_weight"
if
weight_name
in
gate_up
else
"experts.w2_weight"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
enumerate
(
gate_down_up
)
]
+
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
(
"experts.a13_scale"
if
weight_name
in
gate_up
else
"experts.a2_scale"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.input_scale"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
enumerate
(
gate_down_up
)
(
"experts.w13_"
if
weight_name
in
[
ckpt_gate_proj_name
,
ckpt_up_proj_name
]
else
"experts.w2_"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
."
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
[
(
"w1"
,
ckpt_gate_proj_name
),
(
"w2"
,
ckpt_down_proj_name
),
(
"w3"
,
ckpt_up_proj_name
),
]
]
def
_load_fp8_scale
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
param_data
=
param
.
data
# Input scales can be loaded directly and should be equal.
if
"input_scale"
in
weight_name
:
if
param_data
[
expert_id
]
!=
1
and
(
param_data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param_data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
param_data
[
expert_id
]
=
loaded_weight
# Weight scales
elif
"weight_scale"
in
weight_name
:
# If we are in merged column case (gate_up_proj)
if
shard_id
in
(
"w1"
,
"w3"
):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx
=
0
if
shard_id
==
"w1"
else
1
param_data
[
expert_id
][
idx
]
=
loaded_weight
# If we are in the row parallel case (down_proj)
else
:
param_data
[
expert_id
]
=
loaded_weight
vllm/model_executor/layers/quantization/fp8.py
View file @
d3bdfd3a
...
...
@@ -290,23 +290,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_scale"
,
w13_scale
)
w13_
weight_
scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_
weight_
scale"
,
w13_
weight_
scale
)
w2_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_scale"
,
w2_scale
)
w2_
weight_
scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_
weight_
scale"
,
w2_
weight_
scale
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
set_weight_attrs
(
w13_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w13_weight_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
set_weight_attrs
(
w2_weight_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
# INPUT_SCALES
if
self
.
quant_config
.
activation_scheme
==
"static"
:
...
...
@@ -315,20 +321,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
a13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"a13_scale"
,
a13_scale
)
set_weight_attrs
(
a13_scale
,
extra_weight_attrs
)
a2_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"a2_scale"
,
a2_scale
)
set_weight_attrs
(
a2_scale
,
extra_weight_attrs
)
w13_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
set_weight_attrs
(
w13_input_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
w2_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
set_weight_attrs
(
w2_input_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
else
:
layer
.
a
13_scale
=
None
layer
.
a2
_scale
=
None
layer
.
w
13_
input_
scale
=
None
layer
.
w2_input
_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
...
...
@@ -341,16 +353,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer
.
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
layer
.
w13_
weight_
scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
layer
.
num_experts
,
dtype
=
torch
.
float32
,
device
=
w13_weight
.
device
),
requires_grad
=
False
)
requires_grad
=
False
)
for
expert
in
range
(
layer
.
num_experts
):
w13_weight
[
expert
,
:,
:],
layer
.
w13_scale
[
w13_weight
[
expert
,
:,
:],
layer
.
w13_
weight_
scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
layer
.
w13_weight
.
data
[
expert
,
:,
:])
w2_weight
[
expert
,
:,
:],
layer
.
w2_scale
[
w2_weight
[
expert
,
:,
:],
layer
.
w2_
weight_
scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
.
data
[
expert
,
:,
:])
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
...
...
@@ -366,40 +378,41 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
layer
.
a13_scale
is
None
or
layer
.
a2_scale
is
None
:
if
(
layer
.
w13_input_scale
is
None
or
layer
.
w2_input_scale
is
None
):
raise
ValueError
(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
(
not
all_close_1d
(
layer
.
a
13_scale
)
or
not
all_close_1d
(
layer
.
a2
_scale
)):
if
(
not
all_close_1d
(
layer
.
w
13_
input_
scale
)
or
not
all_close_1d
(
layer
.
w2_input
_scale
)):
print_warning_once
(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer
.
a
13_scale
=
torch
.
nn
.
Parameter
(
layer
.
a13_scale
.
max
(),
requires_grad
=
False
)
layer
.
a2
_scale
=
torch
.
nn
.
Parameter
(
layer
.
a2_scale
.
max
(),
requires_grad
=
False
)
layer
.
w
13_
input_
scale
=
torch
.
nn
.
Parameter
(
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input
_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert
layer
.
w13_scale
is
not
None
assert
layer
.
w13_
weight_
scale
is
not
None
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
layer
.
w13_scale
.
max
(
dim
=
1
).
values
max_w13_scales
=
layer
.
w13_
weight_
scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
layer
.
num_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
layer
.
w13_scale
[
expert_id
][
shard_id
])
layer
.
w13_
weight_
scale
[
expert_id
][
shard_id
])
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
_
=
ops
.
scaled_fp8_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
start
+=
shard_size
layer
.
w13_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
layer
.
w13_
weight_
scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
return
def
apply
(
self
,
...
...
@@ -407,27 +420,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_moe
return
fused_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
top_k
,
renormalize
=
renormalize
,
inplace
=
True
,
use_fp8
=
True
,
w1_scale
=
layer
.
w13_scale
,
w2_scale
=
layer
.
w2_scale
,
a1_scale
=
layer
.
a13_scale
,
a2_scale
=
layer
.
a2_scale
,
use_grouped_topk
=
use_grouped_topk
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
renormalize
:
bool
,
use_grouped_topk
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_fp8
=
True
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
class
Fp8KVCacheMethod
(
BaseKVCacheMethod
):
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
d3bdfd3a
...
...
@@ -593,7 +593,7 @@ class DeepseekV2ForCausalLM(nn.Module):
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_
name
,
name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
break
...
...
vllm/model_executor/models/jamba.py
View file @
d3bdfd3a
...
...
@@ -930,7 +930,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_
name
,
name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
break
...
...
vllm/model_executor/models/mixtral.py
View file @
d3bdfd3a
...
...
@@ -455,7 +455,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_
name
,
name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
break
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
d3bdfd3a
...
...
@@ -492,7 +492,7 @@ class Qwen2MoeForCausalLM(nn.Module):
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_
name
,
name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
break
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment