Unverified Commit 88799448, authored by HAI, committed by GitHub

ROCm/AITER CK_MoE: update 2-stage kernels & support both Activations (#5228)

parent a879811c
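
For readers skimming the hunks below: the change (a) maps SGLang's string `activation` argument onto AITER's `ActivationType` enum and (b) routes the INT4-FP8 and plain FP8 MoE paths to the CK 2-stage kernels (`ck_moe_2stages_win4`, `ck_moe_2stages`), keeping `asm_moe` only for FP8 block-quant. A minimal sketch of the activation mapping, using only names visible in this diff; the helper function itself is illustrative and not part of aiter or sglang:

```python
# Illustrative sketch only (not part of aiter/sglang): mirrors the inline
# conditional added in this commit, which converts SGLang's activation string
# into AITER's ActivationType enum before calling the CK 2-stage kernels.
from aiter import ActivationType  # ROCm-only dependency


def to_aiter_activation(activation: str) -> ActivationType:
    # The diff distinguishes only "silu"; anything else falls back to GELU.
    return ActivationType.Silu if activation == "silu" else ActivationType.Gelu
```

The same conditional appears inline at each `ck_moe_2stages*` call site in the hunks below.
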
@@ -71,7 +71,8 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 _is_hip = is_hip()
 if _is_hip:
-    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight
 _is_cuda = is_cuda()
@@ -487,7 +488,7 @@
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
-                torch.int32
+                torch.uint32
                 if get_bool_env_var("USE_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
@@ -822,12 +823,14 @@
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
             # Weight Permutation
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -867,12 +870,14 @@
             if get_bool_env_var("CK_MOE"):
                 layer.w13_weight = torch.nn.Parameter(
-                    permute_weight(layer.w13_weight.data),
+                    # permute_weight(layer.w13_weight.data),
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
                 layer.w2_weight = torch.nn.Parameter(
-                    permute_weight(layer.w2_weight.data),
+                    # permute_weight(layer.w2_weight.data),
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
@@ -928,7 +933,7 @@
         if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
             assert not no_combine, f"{no_combine=} is not supported."
-            return asm_moe(
+            return ck_moe_2stages_win4(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -936,15 +941,17 @@
                 topk_ids,
                 layer.w13_weight_scale1,
                 layer.w2_weight_scale1,
-                activation=activation,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         if _is_hip and get_bool_env_var("CK_MOE"):
-            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
-            assert (
-                activation == "silu"
-            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
+                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"CK_MOE: FP8 block_quant {activation=} will be supported later, unset CK_MOE"
                 return asm_moe(
                     x,
                     layer.w13_weight,
@@ -957,7 +964,7 @@
                     expert_mask=None,
                 )
             else:
-                return asm_moe(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
@@ -965,6 +972,11 @@
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                 )
         else:
             # Expert fusion with FP8 quantization
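
The weight-layout hunks above repeat one pattern for both `w13_weight` and `w2_weight`: the old `permute_weight` call is commented out and replaced by aiter's `shuffle_weight` with a (16, 16) tile, the layout the CK 2-stage kernels consume. A hedged sketch of that step, assuming a layer object holding `w13_weight`/`w2_weight` parameters as in the diff; the wrapper name is hypothetical:

```python
# Illustrative sketch only: re-registers the MoE weights in the (16, 16)-shuffled
# layout expected by ck_moe_2stages / ck_moe_2stages_win4, mirroring the hunks above.
# The function name _shuffle_ck_moe_weights is hypothetical.
import torch
from aiter.ops.shuffle import shuffle_weight  # ROCm-only dependency


def _shuffle_ck_moe_weights(layer: torch.nn.Module) -> None:
    for name in ("w13_weight", "w2_weight"):
        weight = getattr(layer, name)
        shuffled = torch.nn.Parameter(
            shuffle_weight(weight.data, (16, 16)), requires_grad=False
        )
        setattr(layer, name, shuffled)
        # Release the un-shuffled copy promptly, as the diff does after each swap.
        torch.cuda.empty_cache()
```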