"tests/vscode:/vscode.git/clone" did not exist on "6459a688ae15d797dd4d0586f2f8ad2e46d58145"
Unverified commit f445a1d9 authored by Hubert Lu, committed by GitHub

[AMD] Fix Llama 4 FP8 accuracy issues on MI300X (#7699)

parent e5638573
@@ -52,7 +52,6 @@ if not (_is_npu or _is_hip):
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
-    from aiter.ops.shuffle import shuffle_weight

 logger = logging.getLogger(__name__)
...
new file: python/sglang/srt/layers/moe/rocm_moe_utils.py

# Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1rc2/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import IntEnum
from functools import cache
from typing import Optional

import torch

from sglang.srt.utils import direct_register_custom_op, get_bool_env_var, is_hip

_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip


class ActivationMethod(IntEnum):
    # This allows interfacing with AITER ActivationType enum
    # without importing the ActivationType enum from AITER globally.
    SILU = 0
    GELU = 1


def rocm_aiter_asm_moe_tkw1_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: Optional[torch.Tensor] = None,
    fc2_scale: Optional[torch.Tensor] = None,
    fc1_smooth_scale: Optional[torch.Tensor] = None,
    fc2_smooth_scale: Optional[torch.Tensor] = None,
    a16: bool = False,
    per_tensor_quant_scale: Optional[torch.Tensor] = None,
    expert_mask: Optional[torch.Tensor] = None,
    activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
    from aiter import ActivationType
    from aiter.fused_moe_bf16_asm import asm_moe_tkw1

    activation = ActivationType(activation_method)

    return asm_moe_tkw1(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        fc1_scale=fc1_scale,
        fc2_scale=fc2_scale,
        fc1_smooth_scale=fc1_smooth_scale,
        fc2_smooth_scale=fc2_smooth_scale,
        a16=a16,
        per_tensor_quant_scale=per_tensor_quant_scale,
        expert_mask=expert_mask,
        activation=activation,
    )


def rocm_aiter_asm_moe_tkw1_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: Optional[torch.Tensor] = None,
    fc2_scale: Optional[torch.Tensor] = None,
    fc1_smooth_scale: Optional[torch.Tensor] = None,
    fc2_smooth_scale: Optional[torch.Tensor] = None,
    a16: bool = False,
    per_tensor_quant_scale: Optional[torch.Tensor] = None,
    expert_mask: Optional[torch.Tensor] = None,
    activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


if _use_aiter:
    direct_register_custom_op(
        op_name="rocm_aiter_asm_moe_tkw1",
        op_func=rocm_aiter_asm_moe_tkw1_impl,
        mutates_args=[],
        fake_impl=rocm_aiter_asm_moe_tkw1_fake,
    )


def rocm_fused_experts_tkw1(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    per_channel_quant: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    activation_method = (
        ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU
    )
    # All AITER Fused MoE kernels are expecting the following datatypes
    topk_weights = topk_weights.to(torch.float32)
    topk_ids = topk_ids.to(torch.int32)

    # w8a8 per-channel quantization
    if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
        # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
        # This applies topk_weights on the GEMM output of the first FC layer
        # rather than the second FC.
        assert (
            topk_weights.dim() == 2
        ), "`topk_weights` should be in shape (num_tokens, topk)"
        assert topk_weights.shape[-1] == 1, (
            "Only support topk=1 when" " `apply_router_weight_on_input` is True"
        )

        return torch.ops.sglang.rocm_aiter_asm_moe_tkw1(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            fc1_scale=w1_scale,
            fc2_scale=w2_scale,
            fc1_smooth_scale=None,
            fc2_smooth_scale=None,
            a16=False,
            per_tensor_quant_scale=None,
            expert_mask=None,
            activation_method=activation_method,
        )
    else:
        assert False, "This should not be called."
@@ -19,7 +19,14 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+    set_weight_attrs,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -29,6 +36,13 @@ if TYPE_CHECKING:
         CompressedTensorsConfig,
     )

+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+
+    from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1
+
 try:
     import vllm
@@ -265,6 +279,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             max_w13_scales, requires_grad=False
         )

+        if _use_aiter:
+            with torch.no_grad():
+                # Pre-shuffle weights
+                layer.w13_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -274,20 +302,43 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts

-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
-            use_fp8_w8a8=True,
-            per_channel_quant=self.weight_quant.strategy
-            == QuantizationStrategy.CHANNEL,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-        )
+        if (
+            _use_aiter
+            and self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and moe_runner_config.apply_router_weight_on_input
+        ):
+            topk_weights, topk_ids, _ = topk_output
+            return rocm_fused_experts_tkw1(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=moe_runner_config.activation,
+                apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy
+                == QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+        else:
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy
+                == QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )


 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
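
To restate the routing decision added in CompressedTensorsW8A8Fp8MoEMethod.apply above: the AITER tkw1 kernel is only taken when all three conditions hold, because it applies the single router weight on the output of the first FC layer and everything else falls back to the Triton fused_experts path. A condensed, illustrative restatement (the helper name is not from the codebase):

def use_aiter_tkw1_path(
    use_aiter: bool,
    weight_strategy_is_channel: bool,
    apply_router_weight_on_input: bool,
) -> bool:
    # Mirrors the condition in the diff above: AITER enabled on ROCm,
    # per-channel FP8 weight quantization, and the router weight applied on
    # the input (Llama 4 style routing with topk == 1).
    return (
        use_aiter
        and weight_strategy_is_channel
        and apply_router_weight_on_input
    )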
@@ -966,6 +966,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             # ROCm (_use_aiter): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
...
@@ -2228,7 +2228,10 @@ class ServerArgs:
             # use bf16 for mxfp4 triton kernels
             self.dtype = "bfloat16"
         elif "Llama4" in model_arch:
-            assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
+            assert self.attention_backend in {
+                "fa3",
+                "aiter",
+            }, "fa3 or aiter is required for Llama4 model"
         elif model_arch in [
             "Gemma2ForCausalLM",
             "Gemma3ForCausalLM",
...
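
As a usage illustration (not from this commit): with the relaxed assert, Llama 4 FP8 can be run with the AITER attention backend on MI300X. The sketch below is an assumption-laden example; the checkpoint name, tp_size, and the offline Engine keyword arguments are placeholders rather than a documented recipe.

import os

# The AITER switch is read at import time via get_bool_env_var("SGLANG_USE_AITER"),
# so it must be set before SGLang modules are imported.
os.environ["SGLANG_USE_AITER"] = "1"

import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",  # placeholder checkpoint
        attention_backend="aiter",  # now accepted for Llama4 alongside "fa3"
        tp_size=8,  # placeholder parallelism
    )
    print(llm.generate("The capital of France is", {"temperature": 0.0, "max_new_tokens": 8}))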