Unverified commit 51d25405, authored by HAI, committed by GitHub

ROCm: update aiter and its usage in fused MoE (bfloat16, fp8, fp8 block-quant) (#4053)

parent e0a2c963
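
For context: with this change, MoE execution on ROCm is routed through aiter's CK kernels whenever the `CK_MOE` environment variable is set, with separate paths for bfloat16 weights (`ck_moe`), per-tensor FP8 (`asm_moe` with column-wise `scale1` tensors), and FP8 block-quant (`asm_moe` with block-wise inverse scales). Below is a minimal sketch of that routing, with stub kernels standing in for the aiter and Triton calls; `dispatch_moe` and the `*_stub` functions are illustrative helpers, not sglang or aiter APIs.

```python
import os


def get_bool_env_var(name: str) -> bool:
    # Same spirit as the helper used in the diff below.
    return os.environ.get(name, "false").lower() in ("1", "true")


# Stubs standing in for aiter.ck_moe, aiter's asm_moe, and the Triton fused_experts.
def ck_moe_stub(x, *args, **kwargs):
    return x


def asm_moe_stub(x, *args, **kwargs):
    return x


def fused_experts_stub(x, *args, **kwargs):
    return x


def dispatch_moe(x, layer, topk_weights, topk_ids, *, on_rocm, use_fp8, block_quant):
    """Illustrative routing only; the real logic lives in the hunks below."""
    if on_rocm and get_bool_env_var("CK_MOE"):
        if not use_fp8:
            # bfloat16 weights -> ck_moe (fused MoE layer hunk)
            return ck_moe_stub(x, layer["w13_weight"], layer["w2_weight"],
                               topk_weights, topk_ids)
        if block_quant:
            # FP8 block-quant -> asm_moe with block-wise inverse scales (fp8 hunk)
            return asm_moe_stub(x, layer["w13_weight"], layer["w2_weight"],
                                topk_weights, topk_ids,
                                layer["w13_weight_scale_inv"], layer["w2_weight_scale_inv"])
        # Per-tensor FP8 -> asm_moe with column-wise scale1 tensors (fp8 hunk)
        return asm_moe_stub(x, layer["w13_weight"], layer["w2_weight"],
                            topk_weights, topk_ids,
                            layer["w13_weight_scale1"], layer["w2_weight_scale1"])
    # Default path: the existing Triton fused_experts kernel.
    return fused_experts_stub(x, layer["w13_weight"], layer["w2_weight"],
                              topk_weights, topk_ids)
```

At runtime the new paths are opted into with `CK_MOE=1` in the environment; without it, the Triton kernels are used as before.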
@@ -18,7 +18,7 @@ ARG TRITON_COMMIT="improve_fa_decode_3.0.0"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
-ARG AITER_COMMIT="dev/testx"
+ARG AITER_COMMIT="testx"
 RUN git clone ${SGL_REPO} \
     && cd sglang \
...
@@ -51,7 +51,7 @@ srt = [
 ]
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
-# => base docker rocm/vllm-dev:20241022, not from public vllm whl
+# => base docker rocm/vllm-dev:20250114, not from public vllm whl
 srt_hip = ["sglang[runtime_common]", "sgl-kernel==0.0.3.post6", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
 # xpu is not enabled in public vllm and torch whl,
...
@@ -29,6 +29,9 @@ import logging
 is_hip_ = is_hip()
+if is_hip_:
+    from aiter import ck_moe
+
 logger = logging.getLogger(__name__)
@@ -173,18 +176,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         )
         if is_hip_ and get_bool_env_var("CK_MOE"):
-            import aiter
-            from aiter.fused_moe import fused_experts_ck
-
-            assert activation == "silu", f"{activation=} is not supported."
             assert not no_combine, "unsupported"
-            return fused_experts_ck(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
+            return ck_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                None,
+                None,
+                None,
+                None,
+                32,
+                None,
+                activation,
             )
         else:
             return fused_experts(
...
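
For orientation, the tensors passed to `ck_moe` above (and to the `fused_experts` fallback) follow the fused-MoE layout already used in this file: `w13_weight` stacks each expert's gate and up projections, `w2_weight` is the per-expert down projection, and `topk_weights`/`topk_ids` come from the router; the `2 * intermediate_size` and `hidden_size` output dimensions match the `scale1` shapes added in the fp8 hunks below. A small shape sketch with arbitrary toy sizes; the plain softmax-plus-topk routing is a simplification of sglang's expert-selection step, not the exact code.

```python
import torch

E, H, I, T, K = 8, 512, 1024, 4, 2  # experts, hidden, intermediate, tokens, top-k (toy sizes)

x = torch.randn(T, H, dtype=torch.bfloat16)
w13_weight = torch.randn(E, 2 * I, H, dtype=torch.bfloat16)  # gate and up projections stacked
w2_weight = torch.randn(E, H, I, dtype=torch.bfloat16)       # down projection per expert

# Simplified routing; sglang computes these upstream of the call shown in the diff.
router_logits = torch.randn(T, E, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(torch.softmax(router_logits, dim=-1), k=K, dim=-1)

# These five tensors are what ck_moe and fused_experts consume; the trailing
# None / 32 / activation arguments in the diff configure the aiter kernel.
print(x.shape, topk_weights.shape, topk_ids.shape)  # (T, H), (T, K), (T, K)
```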
@@ -51,6 +51,10 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 is_hip_ = is_hip()
+if is_hip_:
+    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter.ops.shuffle import shuffle_weight
+
 logger = logging.getLogger(__name__)
@@ -533,6 +537,20 @@ class Fp8MoEMethod:
             )
             layer.register_parameter("w13_weight_scale", w13_weight_scale)
             layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            if is_hip_ and get_bool_env_var("CK_MOE"):
+                # ROCm: the CK kernels use column-wise scaling; for per-tensor scaling
+                # the single scale is duplicated across all output columns.
+                w13_weight_scale1 = torch.nn.Parameter(
+                    torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                w2_weight_scale1 = torch.nn.Parameter(
+                    torch.ones(num_experts, hidden_size, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                layer.register_parameter("w13_weight_scale1", w13_weight_scale1)
+                layer.register_parameter("w2_weight_scale1", w2_weight_scale1)
+
             # Add the quantization method used (per tensor/grouped/channel)
             # to ensure the weight scales are loaded in properly
             extra_weight_attrs.update(
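
The new `scale1` buffers hold one scale per output column: `2 * intermediate_size` columns for the stacked w13 projection and `hidden_size` columns for w2. When the checkpoint only provides a per-tensor scale, duplicating that value across the columns gives the CK kernel a uniform column-wise view, which is what the later `scale1 *= scale.unsqueeze(-1)` lines rely on. A small numeric check of that equivalence with toy sizes; the `(num_experts,)` scale shape is an assumption implied by the `unsqueeze(-1)`.

```python
import torch

num_experts, out_cols, in_dim = 4, 6, 5            # toy sizes
w_q = torch.randn(num_experts, out_cols, in_dim)   # stand-in for already-quantized weights
per_tensor_scale = torch.rand(num_experts) + 0.5   # one dequant scale per expert

# scale1 starts at ones and absorbs the per-tensor scale, mirroring
# `layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)`.
scale1 = torch.ones(num_experts, out_cols)
scale1 *= per_tensor_scale.unsqueeze(-1)           # (E, out_cols): same value in every column

dequant_per_tensor = w_q * per_tensor_scale[:, None, None]
dequant_column_wise = w_q * scale1[:, :, None]
assert torch.allclose(dequant_per_tensor, dequant_column_wise)
```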
@@ -602,6 +620,15 @@ class Fp8MoEMethod:
                 w2_weight_scale, requires_grad=False
             )
             layer.w2_input_scale = None
+
+            if get_bool_env_var("CK_MOE"):
+                # Pre-shuffle weights into the tiled layout expected by the CK kernels
+                layer.w13_weight.data = shuffle_weight(
+                    layer.w13_weight.contiguous(), (16, 16)
+                )
+                layer.w2_weight.data = shuffle_weight(
+                    layer.w2_weight.contiguous(), (16, 16)
+                )
             return
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -640,6 +667,9 @@ class Fp8MoEMethod:
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
+                # ROCm (CK_MOE): fold the per-tensor weight scale into the column-wise scale1
+                layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+                layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
             elif get_bool_env_var("MOE_PADDING"):
                 # If ROCm, apply weight padding (min. Mem channel contention) only if set
                 layer.w13_weight = torch.nn.Parameter(
@@ -744,6 +774,9 @@ class Fp8MoEMethod:
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
+                # ROCm (CK_MOE): fold the per-tensor weight scale into the column-wise scale1
+                layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+                layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
             elif get_bool_env_var("MOE_PADDING"):
                 # If ROCm, apply weight padding (min. Mem channel contention) only if set
                 layer.w13_weight = torch.nn.Parameter(
@@ -790,34 +823,38 @@ class Fp8MoEMethod:
                 correction_bias=correction_bias,
             )

-            if is_hip_ and get_bool_env_var("CK_MOE"):
-                import aiter
-                from aiter.fused_moe import fused_experts_ck
-
-                assert activation == "silu", f"{activation=} is not supported."
+            if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu":
+                # TODO(CK_MOE): FP8 and FP8 block-quant only support 'silu' for the time being.
                 assert not no_combine, f"{no_combine=} is not supported."
-                return fused_experts_ck(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights=topk_weights,
-                    topk_ids=topk_ids,
-                    use_fp8_w8a8=True,
-                    w1_scale=(
-                        layer.w13_weight_scale_inv
-                        if self.block_quant
-                        else layer.w13_weight_scale
-                    ),
-                    w2_scale=(
-                        layer.w2_weight_scale_inv
-                        if self.block_quant
-                        else layer.w2_weight_scale
-                    ),
-                    a1_scale=layer.w13_input_scale,
-                    a2_scale=layer.w2_input_scale,
-                )
+                if self.block_quant:
+                    return asm_moe(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale_inv,
+                        layer.w2_weight_scale_inv,
+                        None,
+                        None,
+                        False,
+                        None,
+                        block_shape=tuple(self.quant_config.weight_block_size),
+                        expert_mask=None,
+                    )
+                else:
+                    return asm_moe(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale1,
+                        layer.w2_weight_scale1,
+                        None,
+                        None,
+                        False,
+                    )
             else:
                 # Expert fusion with FP8 quantization
                 return fused_experts(
...
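
In the block-quant branch above, `asm_moe` receives `layer.w13_weight_scale_inv` / `layer.w2_weight_scale_inv` together with `block_shape=tuple(self.quant_config.weight_block_size)`: each scale entry covers one `(block_n, block_k)` tile of the weight rather than a whole tensor or column. A small dequantization sketch of what such a scale grid means; the 128x128 block size is an assumption (a common choice for FP8 block-quant checkpoints), not something stated in this diff, and the toy dimensions are chosen to divide evenly.

```python
import torch

block_n, block_k = 128, 128                  # assumed weight_block_size
E, N, K = 2, 512, 384                        # experts, output cols, input dim (toy sizes)

w_q = torch.randn(E, N, K)                                     # stand-in for fp8 weight values
scale_inv = torch.rand(E, N // block_n, K // block_k) + 0.5    # one scale per (128, 128) block

# Expand the per-block scales to per-element and apply them (multiplication shown,
# matching the common weight_scale_inv convention for block-quant checkpoints).
scale_full = scale_inv.repeat_interleave(block_n, dim=1).repeat_interleave(block_k, dim=2)
w_dequant = w_q * scale_full
assert w_dequant.shape == w_q.shape
```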