Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ed246203
Unverified
Commit
ed246203
authored
Apr 28, 2025
by
Charlie Fu
Committed by
GitHub
Apr 28, 2025
Browse files
[Bugfix] Fix moe weight losing all extra attrs after `process_weights_after_loading`. (#16854)
Signed-off-by:
charlifu
<
charlifu@amd.com
>
parent
cc5befbc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
11 deletions
+6
-11
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+5
-10
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+1
-1
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
ed246203
...
...
@@ -113,12 +113,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
super
().
process_weights_after_loading
(
layer
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
self
.
_maybe_pad_weight
(
layer
.
w13_weight
.
data
),
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
self
.
_maybe_pad_weight
(
layer
.
w2_weight
.
data
),
requires_grad
=
False
)
# Padding the weight for better performance on ROCm
layer
.
w13_weight
.
data
=
self
.
_maybe_pad_weight
(
layer
.
w13_weight
.
data
)
layer
.
w2_weight
.
data
=
self
.
_maybe_pad_weight
(
layer
.
w2_weight
.
data
)
# Lazy import to avoid importing triton.
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
is_rocm_aiter_moe_enabled
,
shuffle_weights
)
...
...
@@ -127,10 +124,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
layer
.
w13_weight
.
data
=
shuffled_w13
layer
.
w2_weight
.
data
=
shuffled_w2
if
current_platform
.
is_cpu
():
if
current_platform
.
get_cpu_architecture
()
==
CpuArchEnum
.
X86
:
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
ed246203
...
...
@@ -156,7 +156,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
input_2d
:
torch
.
Tensor
,
output_shape
:
List
)
->
torch
.
Tensor
:
from
vllm.platforms.rocm
import
on_mi250_mi300
if
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
not
on_mi250_mi300
(
if
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
on_mi250_mi300
(
)
and
qinput
.
shape
[
0
]
==
1
and
qinput
.
shape
[
1
]
%
16
==
0
:
output
=
ops
.
wvSplitKQ
(
weight
.
t
(),
qinput
,
out_dtype
,
scale_a
,
scale_b
,
current_platform
.
get_cu_count
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment