Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d260f799
Unverified
Commit
d260f799
authored
May 27, 2025
by
vllmellm
Committed by
GitHub
May 26, 2025
Browse files
[FEAT] [ROCm] Upgrade AITER Fused MoE kernels. (#18271)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
b50602d5
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
133 additions
and
317 deletions
+133
-317
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+2
-4
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+123
-278
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+2
-3
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+6
-32
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
d260f799
...
...
@@ -419,10 +419,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shuffle_weights
)
if
self
.
rocm_aiter_moe_enabled
:
# use 2stage ck moe layout
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
,
layout
=
(
32
,
32
))
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
.
data
=
shuffled_w13
layer
.
w2_weight
.
data
=
shuffled_w2
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
d260f799
# SPDX-License-Identifier: Apache-2.0
from
enum
import
IntEnum
from
functools
import
cache
from
typing
import
Optional
...
...
@@ -9,6 +10,28 @@ from vllm.platforms import current_platform
from
vllm.utils
import
direct_register_custom_op
class
QuantMethod
(
IntEnum
):
# This allows interfacing with AITER QuantType Enum
# without importing the QuantType from AITER globally.
# Note that these quantization methods are
# supported in AITER package. However,
# not all are used in this module.
NO
=
0
# a16w16
PER_TENSOR
=
1
# w8a8 (pre_Tensor)
PER_TOKEN
=
2
# w8a8/w8a4 (per_Token)
BLOCK_1X128
=
3
# block quantized w8a8 (per_1x128)
BLOCK_128x128
=
4
# block quantized w8a8 (per_128x128)
class
ActivationMethod
(
IntEnum
):
# This allows interfacing with AITER ActivationType enum
# without importing the ActivationType enum from AITER globally.
SILU
=
0
GELU
=
1
@
cache
def
is_rocm_aiter_moe_enabled
()
->
bool
:
return
current_platform
.
is_rocm
()
\
...
...
@@ -29,13 +52,12 @@ def rocm_aiter_asm_moe_tkw1_impl(
a16
:
bool
=
False
,
per_tensor_quant_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
activation_
str
:
str
=
"si
lu
"
)
->
torch
.
Tensor
:
activation_
method
:
int
=
ActivationMethod
.
SILU
.
va
lu
e
)
->
torch
.
Tensor
:
from
aiter
import
ActivationType
from
aiter.fused_moe_bf16_asm
import
asm_moe_tkw1
activation
=
\
ActivationType
.
Gelu
if
activation_str
==
"gelu"
else
ActivationType
.
Silu
activation
=
ActivationType
(
activation_method
)
return
asm_moe_tkw1
(
hidden_states
,
w1
,
...
...
@@ -65,163 +87,7 @@ def rocm_aiter_asm_moe_tkw1_fake(
a16
:
bool
=
False
,
per_tensor_quant_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
activation_str
:
str
=
"silu"
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
def
rocm_aiter_fmoe_fp8_blockscale_g1u1_impl
(
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
hidden_states_dtype
:
torch
.
dtype
,
expert_mask
:
torch
.
Tensor
,
a1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
a1_scale
:
torch
.
Tensor
,
block_shape
:
list
[
int
],
smooth_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
from
aiter
import
fmoe_fp8_blockscale_g1u1
from
aiter.fused_moe_bf16_asm
import
moe_sorting_ck
topk
=
topk_ids
.
shape
[
1
]
model_dim
=
w1
.
shape
[
-
1
]
local_E
=
E
=
w1
.
shape
[
0
]
if
expert_mask
is
not
None
:
E
=
expert_mask
.
numel
()
(
sorted_token_ids
,
sorted_weight_buf
,
sorted_expert_ids
,
num_valid_ids
,
out_asm
,
)
=
moe_sorting_ck
(
topk_ids
,
topk_weights
,
E
,
model_dim
,
hidden_states_dtype
,
expert_mask
=
expert_mask
)
fmoe_fp8_blockscale_g1u1
(
out_asm
,
a1
,
w1
,
w2
,
sorted_token_ids
,
sorted_weight_buf
,
sorted_expert_ids
,
num_valid_ids
,
topk
,
a1_scale
.
t
().
contiguous
(),
w1_scale
.
view
(
local_E
,
-
1
),
w2_scale
.
view
(
local_E
,
-
1
),
*
block_shape
,
smooth_scale
)
return
out_asm
def
rocm_aiter_fmoe_fp8_blockscale_g1u1_fake
(
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
hidden_states_dtype
:
torch
.
dtype
,
expert_mask
:
torch
.
Tensor
,
a1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
a1_scale
:
torch
.
Tensor
,
block_shape
:
list
[
int
],
smooth_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
a1
,
dtype
=
hidden_states_dtype
)
def
rocm_aiter_asm_moe_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
fc1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc1_smooth_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc2_smooth_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a16
:
bool
=
False
,
activation
:
str
=
"silu"
)
->
torch
.
Tensor
:
import
aiter.fused_moe_bf16_asm
as
rocm_aiter_asm_fmoe
from
aiter
import
ActivationType
assert
activation
in
[
"silu"
,
"gelu"
],
"The given activation:"
\
f
"
{
activation
}
"
\
" is not supported in"
\
" AITER."
if
activation
==
"silu"
:
aiter_activation
=
ActivationType
.
Silu
else
:
aiter_activation
=
ActivationType
.
Gelu
return
rocm_aiter_asm_fmoe
.
asm_moe
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weight
=
topk_weights
,
topk_ids
=
topk_ids
,
fc1_scale
=
fc1_scale
,
fc2_scale
=
fc2_scale
,
fc1_smooth_scale
=
fc1_smooth_scale
,
fc2_smooth_scale
=
fc2_smooth_scale
,
a16
=
a16
,
activation
=
aiter_activation
)
def
rocm_aiter_asm_moe_fake
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
fc1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc1_smooth_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc2_smooth_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a16
:
bool
=
False
,
activation
:
str
=
"silu"
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
def
rocm_aiter_ck_moe_2stages_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
fc1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_size
:
Optional
[
list
[
int
]]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
from
aiter.fused_moe_bf16_asm
import
ck_moe_2stages
return
ck_moe_2stages
(
a1
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weight
=
topk_weights
,
topk_ids
=
topk_ids
,
fc1_scale
=
fc1_scale
,
fc2_scale
=
fc2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_size
=
block_size
,
expert_mask
=
expert_mask
)
def
rocm_aiter_ck_moe_2stages_fake
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
fc1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_size
:
Optional
[
list
[
int
]]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
activation_method
:
int
=
ActivationMethod
.
SILU
.
value
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
...
...
@@ -274,6 +140,50 @@ def rocm_aiter_biased_grouped_topk_fake(
pass
def
rocm_aiter_fused_moe_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
activation_method
:
int
=
ActivationMethod
.
SILU
.
value
,
quant_method
:
int
=
QuantMethod
.
NO
.
value
,
doweight_stage1
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
from
aiter
import
ActivationType
,
QuantType
from
aiter.fused_moe
import
fused_moe
activation
=
ActivationType
(
activation_method
)
quant_type
=
QuantType
(
quant_method
)
return
fused_moe
(
hidden_states
,
w1
,
w2
,
topk_weight
,
topk_ids
,
expert_mask
,
activation
,
quant_type
,
doweight_stage1
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
)
def
rocm_aiter_fused_moe_fake
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
activation_method
:
int
=
ActivationMethod
.
SILU
.
value
,
quant_method
:
int
=
QuantMethod
.
NO
.
value
,
doweight_stage1
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
if
current_platform
.
is_rocm
():
direct_register_custom_op
(
...
...
@@ -285,26 +195,10 @@ if current_platform.is_rocm():
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_f
moe_fp8_blockscale_g1u1
"
,
op_func
=
rocm_aiter_f
moe_fp8_blockscale_g1u1
_impl
,
op_name
=
"rocm_aiter_f
used_moe
"
,
op_func
=
rocm_aiter_f
used_moe
_impl
,
mutates_args
=
[],
fake_impl
=
rocm_aiter_fmoe_fp8_blockscale_g1u1_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_asm_moe"
,
op_func
=
rocm_aiter_asm_moe_impl
,
mutates_args
=
[],
fake_impl
=
rocm_aiter_asm_moe_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_ck_moe_2stages"
,
op_func
=
rocm_aiter_ck_moe_2stages_impl
,
mutates_args
=
[],
fake_impl
=
rocm_aiter_ck_moe_2stages_fake
,
fake_impl
=
rocm_aiter_fused_moe_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
...
...
@@ -373,32 +267,14 @@ def rocm_aiter_fused_experts(
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
activation_method
=
(
ActivationMethod
.
SILU
if
activation
==
"silu"
else
ActivationMethod
.
GELU
)
# All AITER Fused MoE kernels are expecting the following datatypes
topk_weights
=
topk_weights
.
to
(
torch
.
float32
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
# w8a8 block-scaled
if
block_shape
is
not
None
and
use_fp8_w8a8
:
assert
not
apply_router_weight_on_input
,
(
"apply_router_weight_on_input is not supported for block scaled moe"
)
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
# The default block sizes are 128 in AITER.
block_shape
=
[
128
,
128
]
if
block_shape
is
None
else
block_shape
a1
,
a1_scale
=
per_token_group_quant_fp8
(
hidden_states
,
block_shape
[
1
])
return
torch
.
ops
.
vllm
.
rocm_aiter_fmoe_fp8_blockscale_g1u1
(
topk_ids
,
topk_weights
,
hidden_states
.
dtype
,
None
,
a1
,
w1
,
w2
,
w1_scale
,
w2_scale
,
a1_scale
,
block_shape
,
None
)
# w8a8 per-channel quantization
el
if
per_channel_quant
and
apply_router_weight_on_input
and
use_fp8_w8a8
:
if
per_channel_quant
and
apply_router_weight_on_input
and
use_fp8_w8a8
:
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# This applies topk_weights on the GEMM output of the first FC layer
# rather than the second FC.
...
...
@@ -421,42 +297,23 @@ def rocm_aiter_fused_experts(
a16
=
False
,
per_tensor_quant_scale
=
None
,
expert_mask
=
None
,
activation_
str
=
activation
)
activation_
method
=
activation
_method
)
# w8a8 per-tensor activation per-tensor weight
elif
use_fp8_w8a8
:
else
:
quant_method
=
QuantMethod
.
NO
.
value
# w8a8 block-scaled
if
block_shape
is
not
None
and
use_fp8_w8a8
:
assert
not
apply_router_weight_on_input
,
(
"apply_router_weight_on_input is not supported for fp8_w8a8"
)
# - faster static per-tensor-activation static per-tensor-weight
# fp8 quantization w8a8
if
a1_scale
is
not
None
and
a2_scale
is
not
None
:
return
torch
.
ops
.
vllm
.
rocm_aiter_ck_moe_2stages
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
fc1_scale
=
w1_scale
,
fc2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
)
# - fallback static per-tensor-activation static per-tensor-weight
# fp8 quantization w8a8
# - dynamic per-tensor activation static per-tensor-weight
# fp8 quantization w8a8
return
torch
.
ops
.
vllm
.
rocm_aiter_asm_moe
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
fc1_scale
=
w1_scale
,
fc2_scale
=
w2_scale
,
fc1_smooth_scale
=
None
,
fc2_smooth_scale
=
None
,
a16
=
False
,
activation
=
activation
)
"apply_router_weight_on_input is
\
not supported for block scaled moe"
)
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
quant_method
=
QuantMethod
.
BLOCK_128x128
.
value
elif
use_fp8_w8a8
:
# Currently only per tensor quantization method is enabled.
quant_method
=
QuantMethod
.
PER_TENSOR
.
value
if
apply_router_weight_on_input
:
assert
(
topk_weights
.
dim
()
==
2
),
"`topk_weights` should be in shape (num_tokens, topk)"
...
...
@@ -465,16 +322,19 @@ def rocm_aiter_fused_experts(
topk
==
1
),
"Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states
=
hidden_states
*
topk_weights
.
to
(
hidden_states
.
dtype
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
topk_weights
=
torch
.
ones_like
(
topk_weights
,
dtype
=
torch
.
float32
)
return
torch
.
ops
.
vllm
.
rocm_aiter_ck_moe_2stages
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
)
return
torch
.
ops
.
vllm
.
rocm_aiter_fused_moe
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
quant_method
=
quant_method
,
activation_method
=
activation_method
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
doweight_stage1
=
apply_router_weight_on_input
)
def
rocm_aiter_topk_softmax
(
topk_weights
:
torch
.
Tensor
,
...
...
@@ -488,14 +348,21 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
return
topk_weights
,
topk_indices
def
shuffle_weights
(
*
tensors
:
torch
.
Tensor
,
layout
:
tuple
[
int
,
int
])
->
tuple
[
torch
.
Tensor
,
...]:
def
shuffle_weights
(
*
tensors
:
torch
.
Tensor
,
layout
:
tuple
[
int
,
int
]
=
(
16
,
16
)
)
->
tuple
[
torch
.
Tensor
,
...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
Rearranges (shuffles) the input tensor/s
into a specified block layout for optimized computation.
Args:
*tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the
block sizes used to divide the tensors during shuffling.
Default is (16, 16).
Returns:
A Tuple of shuffled tensors.
...
...
@@ -503,25 +370,3 @@ def shuffle_weights(*tensors: torch.Tensor,
from
aiter.ops.shuffle
import
shuffle_weight
return
tuple
(
shuffle_weight
(
tensor
,
layout
=
layout
)
for
tensor
in
tensors
)
def
expand_weights
(
*
tensors
:
torch
.
Tensor
,
expansion_dims
:
list
[
int
])
->
tuple
[
torch
.
Tensor
,
...]:
"""
Expands the dimensions of input tensors.
Args:
*tensors: A variable number of torch.Tensor objects.
expansion_dims: A list of expansion dimensions
corresponding to each tensor.
Returns:
A Tuple of tensors with expanded dimensions.
"""
assert
len
(
tensors
)
==
len
(
expansion_dims
),
\
"Number of tensors must match the number of expansion dimensions."
return
tuple
(
tensor
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
).
expand
((
-
1
,
dim
,
-
1
))
for
tensor
,
dim
in
zip
(
tensors
,
expansion_dims
))
\ No newline at end of file
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
d260f799
...
...
@@ -286,9 +286,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
rocm_aiter_fused_experts
,
shuffle_weights
)
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
,
layout
=
(
16
,
16
))
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
d260f799
...
...
@@ -595,7 +595,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Lazy import to avoid importing triton too early.
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
expand_weights
,
is_rocm_aiter_moe_enabled
,
shuffle_weights
)
is_rocm_aiter_moe_enabled
,
shuffle_weights
)
self
.
rocm_aiter_moe_enabled
=
is_rocm_aiter_moe_enabled
()
...
...
@@ -627,9 +627,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
self
.
rocm_aiter_moe_enabled
:
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
,
layout
=
(
16
,
16
))
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
...
...
@@ -675,20 +673,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad
=
False
)
if
self
.
rocm_aiter_moe_enabled
:
# reshaping weights is required for aiter moe kernel.
w13_scales
,
w2_scales
=
expand_weights
(
layer
.
w13_weight_scale
.
data
,
layer
.
w2_weight_scale
.
data
,
expansion_dims
=
[
layer
.
w13_weight
.
shape
[
1
],
layer
.
w2_weight
.
shape
[
1
]
])
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
w13_scales
.
contiguous
(),
requires_grad
=
False
)
layer
.
w2_weight_scale
=
torch
.
nn
.
Parameter
(
w2_scales
.
contiguous
(),
requires_grad
=
False
)
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
,
layout
=
(
16
,
16
))
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
...
...
@@ -760,20 +746,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
start
+=
shard_size
if
self
.
rocm_aiter_moe_enabled
:
# reshaping weights is required for aiter moe kernel.
expansion_dims
=
[
layer
.
w13_weight
.
shape
[
1
],
layer
.
w2_weight
.
shape
[
1
]
]
max_w13_scales
,
w2_scales
=
expand_weights
(
max_w13_scales
,
layer
.
w2_weight_scale
.
data
,
expansion_dims
=
expansion_dims
)
layer
.
w2_weight_scale
=
torch
.
nn
.
Parameter
(
w2_scales
.
contiguous
(),
requires_grad
=
False
)
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
,
layout
=
(
32
,
32
))
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment