"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "fefbf8f74bdedeb4a6066f0b958a52194de91240"
Unverified commit 88799448, authored by HAI, committed by GitHub

ROCm/AITER CK_MoE: update 2-stage kernels & support both Activations (#5228)

parent a879811c
@@ -71,7 +71,8 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 _is_hip = is_hip()
 if _is_hip:
-    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight
 
 _is_cuda = is_cuda()
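
Reviewer note: the hunk above swaps the single asm_moe import for the activation enum plus the two CK 2-stage entry points. A minimal smoke test, assuming a ROCm build with the AITER package installed (only symbols that appear in this diff are used):

    from aiter import ActivationType
    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
    from aiter.ops.shuffle import shuffle_weight

    # The 2-stage kernels take an ActivationType enum rather than a string.
    print(ActivationType.Silu, ActivationType.Gelu)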
@@ -487,7 +488,7 @@ class Fp8MoEMethod:
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
-                torch.int32
+                torch.uint32
                 if get_bool_env_var("USE_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
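
Reviewer note: with USE_INT4_WEIGHT the checkpoint parameter dtype becomes an unsigned 32-bit container; the usual reason for a 32-bit integer dtype here is that eight 4-bit weights are packed per word, and unsigned packing presumably avoids sign-extension surprises when unpacking. A self-contained illustration of that packing (NumPy, names invented for the example; not SGLang/AITER code):

    import numpy as np

    def pack_int4_to_uint32(nibbles: np.ndarray) -> np.ndarray:
        """Pack groups of eight 4-bit values (0..15) into one uint32 word each."""
        assert nibbles.ndim == 1 and nibbles.size % 8 == 0
        nibbles = nibbles.astype(np.uint32) & 0xF
        shifts = np.arange(8, dtype=np.uint32) * 4           # 0, 4, ..., 28
        return (nibbles.reshape(-1, 8) << shifts).sum(axis=1, dtype=np.uint32)

    packed = pack_int4_to_uint32(np.arange(16, dtype=np.uint32))
    print(packed.dtype, hex(int(packed[0])))                 # uint32 0x76543210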
@@ -822,12 +823,14 @@ class Fp8MoEMethod:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
             # Weight Permutation
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
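
Reviewer note: permute_weight is replaced by AITER's shuffle_weight(..., (16, 16)) pre-layout, apparently the layout the CK 2-stage kernels expect. The exact AITER transform is kernel-specific; the sketch below only illustrates the general idea of regrouping a 2-D weight into contiguous 16x16 tiles and is not the real implementation.

    import torch

    def tile_16x16(w: torch.Tensor) -> torch.Tensor:
        """Illustrative only: reorder a 2-D weight so each 16x16 tile is contiguous."""
        rows, cols = w.shape
        assert rows % 16 == 0 and cols % 16 == 0
        tiles = w.view(rows // 16, 16, cols // 16, 16).permute(0, 2, 1, 3)
        return tiles.contiguous().view(rows, cols)

    w = torch.randn(64, 128)
    print(tile_16x16(w).shape)  # torch.Size([64, 128]) -- same data, tile-major order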
@@ -867,12 +870,14 @@ class Fp8MoEMethod:
             if get_bool_env_var("CK_MOE"):
                 layer.w13_weight = torch.nn.Parameter(
-                    permute_weight(layer.w13_weight.data),
+                    # permute_weight(layer.w13_weight.data),
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
                 layer.w2_weight = torch.nn.Parameter(
-                    permute_weight(layer.w2_weight.data),
+                    # permute_weight(layer.w2_weight.data),
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
@@ -928,7 +933,7 @@ class Fp8MoEMethod:
         if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
             assert not no_combine, f"{no_combine=} is not supported."
-            return asm_moe(
+            return ck_moe_2stages_win4(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -936,15 +941,17 @@ class Fp8MoEMethod:
                 topk_ids,
                 layer.w13_weight_scale1,
                 layer.w2_weight_scale1,
-                activation=activation,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         if _is_hip and get_bool_env_var("CK_MOE"):
-            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
-            assert (
-                activation == "silu"
-            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
+                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
                 return asm_moe(
                     x,
                     layer.w13_weight,
@@ -957,7 +964,7 @@ class Fp8MoEMethod:
                     expert_mask=None,
                 )
             else:
-                return asm_moe(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
@@ -965,6 +972,11 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                 )
         else:
             # Expert fusion with FP8 quantization
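
Reviewer summary of the resulting dispatch in Fp8MoEMethod.apply: INT4 weights go to ck_moe_2stages_win4, FP8 block-quant stays on asm_moe (still SiLU-only), and plain per-tensor FP8 goes to ck_moe_2stages, with both 2-stage kernels taking an aiter.ActivationType instead of a string. The wrapper below is an illustrative paraphrase, not the actual SGLang code, and assumes ROCm + AITER so the imports resolve.

    def select_ck_moe_kernel(use_int4_weight: bool, block_quant: bool, activation: str):
        """Illustrative paraphrase of the kernel selection after this change."""
        from aiter import ActivationType
        from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4

        act = ActivationType.Silu if activation == "silu" else ActivationType.Gelu
        if use_int4_weight:
            return ck_moe_2stages_win4, act        # INT4 weight, FP8 compute
        if block_quant:
            assert activation == "silu", "CK_MOE block-quant path is SiLU-only for now"
            return asm_moe, None                   # block-quant path unchanged
        return ck_moe_2stages, act                 # per-tensor FP8 path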