Unverified Commit b819381f authored by HAI, committed by GitHub

AITER backend extension and workload optimizations (#6838)


Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
parent 562f279a
@@ -72,7 +72,7 @@ jobs:
       - name: Evaluate accuracy (TP=2)
         timeout-minutes: 30
         run: |
-          bash scripts/amd_ci_exec.sh python3 test_moe_eval_accuracy_large.py
+          bash scripts/amd_ci_exec.sh -e SGLANG_USE_AITER=0 python3 test_moe_eval_accuracy_large.py

   mla-test-1-gpu-amd:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
......
@@ -53,7 +53,7 @@ SGLang supports various environment variables that can be used to configure its
 | Environment Variable | Description | Default Value |
 | --- | --- | --- |
-| `SGLANG_AITER_MOE` | Use AITER MOE implementation | `false` |
+| `SGLANG_USE_AITER` | Use the AITER optimized implementations | `false` |
 | `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
 | `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
 | `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
......
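Note: the boolean flags above are read via `get_bool_env_var` from `sglang.srt.utils`. A minimal sketch of the assumed semantics (the real helper may differ slightly):

```python
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Assumed semantics: only "1"/"true" (case-insensitive) enable a flag.
    return os.getenv(name, default).strip().lower() in ("1", "true")

# The module-level gate this commit threads through several files:
# AITER paths are taken only on ROCm (HIP) builds with SGLANG_USE_AITER set.
_is_hip = True  # stand-in for sglang.srt.utils.is_hip() on a ROCm build
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
```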
@@ -20,10 +20,11 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import (
@@ -33,7 +34,10 @@ if _is_cuda:
         rmsnorm,
     )

-if _is_hip:
+if _use_aiter:
+    from aiter import rmsnorm2d_fwd as rms_norm
+    from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
+elif _is_hip:
     from vllm._custom_ops import fused_add_rms_norm, rms_norm

 logger = logging.getLogger(__name__)
@@ -48,6 +52,8 @@ class RMSNorm(CustomOp):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        if _use_aiter:
+            self._forward_method = self.forward_aiter

     def forward_cuda(
         self,
@@ -60,6 +66,25 @@ class RMSNorm(CustomOp):
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out

+    def forward_aiter(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            residual_out = torch.empty_like(x)
+            output = torch.empty_like(x)
+            fused_add_rms_norm(
+                output,
+                x,
+                residual,
+                residual_out,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return output, residual_out
+        return rms_norm(x, self.weight.data, self.variance_epsilon)
+
     def forward_hip(
         self,
         x: torch.Tensor,
......
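Note: the new `forward_aiter` path delegates to aiter's fused kernels, which write into the preallocated buffers shown above. A plain-PyTorch reference of the assumed semantics (residual is added first, then RMS-normalized, and the pre-normalization sum is returned as the new residual; a sketch, not aiter's actual implementation):

```python
import torch
from typing import Optional, Tuple, Union

def rms_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # With a residual: returns (rmsnorm(x + residual), x + residual).
    if residual is not None:
        residual = x + residual
        x = residual
    # Accumulate the variance in fp32, as fused kernels typically do.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out if residual is None else (out, residual)
```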
@@ -1332,7 +1332,7 @@ def fused_experts_impl(
     if (
         not (use_fp8_w8a8 or use_int8_w8a8)
         or block_shape is not None
-        or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE"))
+        or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
     ):
         padded_size = 0
......
@@ -28,8 +28,9 @@ else:
     import logging

 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

-if _is_hip:
+if _use_aiter:
     from aiter import ActivationType
     from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
@@ -104,7 +105,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+        if _use_aiter:
             layer.w13_weight = torch.nn.Parameter(
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
@@ -188,7 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             routed_scaling_factor=routed_scaling_factor,
         )

-        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+        if _use_aiter:
             assert not no_combine, "unsupported"
             if apply_router_weight_on_input:
                 assert (
......
@@ -77,8 +77,8 @@ _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()

-use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
-use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+_use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_hip:
     from aiter import ActivationType, QuantType
@@ -487,7 +487,7 @@ class Fp8MoEMethod:
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

         if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
+            params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn
         tp_size = get_tensor_model_parallel_world_size()
         if self.block_quant:
             block_n, block_k = (
@@ -512,7 +512,7 @@ class Fp8MoEMethod:
         )

         # WEIGHTS
-        if _is_hip and use_hip_int4:
+        if _is_hip and _use_hip_int4:
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -641,7 +641,7 @@ class Fp8MoEMethod:
             layer.register_parameter("w13_weight_scale", w13_weight_scale)
             layer.register_parameter("w2_weight_scale", w2_weight_scale)

-            if _is_hip:  # and use_aiter_moe: TODO: add check back after triton kernel
+            if _is_hip:  # _use_aiter: TODO: add check back after triton kernel
                 # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
                 w13_weight_scale1 = torch.nn.Parameter(
                     torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -668,7 +668,7 @@ class Fp8MoEMethod:
             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
             set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-            if _is_hip and use_hip_int4:
+            if _is_hip and _use_hip_int4:
                 extra_weight_attrs.update(
                     {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
                 )
@@ -700,7 +700,7 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and use_hip_int4:
+        if _is_hip and _use_hip_int4:
             self.process_weights_hip_int4(layer)
             return
@@ -731,7 +731,7 @@ class Fp8MoEMethod:
                 )
                 layer.w2_input_scale = None

-            if _is_hip and use_aiter_moe:
+            if _use_aiter:
                 # Pre-shuffle weights
                 layer.w13_weight.data = shuffle_weight(
                     layer.w13_weight.contiguous(), (16, 16)
@@ -853,7 +853,7 @@ class Fp8MoEMethod:
             return

     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and use_aiter_moe: add after triton kernel added
+        # TODO: _use_aiter: add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
@@ -900,7 +900,7 @@ class Fp8MoEMethod:
             padding_size,  # Avoid circular import
         )

-        if use_aiter_moe:
+        if _use_aiter:
             layer.w13_weight = torch.nn.Parameter(
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
@@ -911,7 +911,7 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
-            # ROCm (use_aiter_moe): using column-wise scaling
+            # ROCm (_use_aiter): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
         elif get_bool_env_var("SGLANG_MOE_PADDING"):
@@ -1041,8 +1041,8 @@ class Fp8MoEMethod:
         activation: str = "silu",
         no_combine: bool = False,
     ) -> Optional[torch.Tensor]:
-        if use_hip_int4:
-            # TODO: add triton kernel and add check use_aiter_moe
+        if _use_hip_int4:
+            # TODO: add triton kernel and add check _use_aiter
             assert not no_combine, f"{no_combine=} is not supported."
             return ck_moe_2stages(
                 x,
@@ -1058,13 +1058,13 @@ class Fp8MoEMethod:
             ),
         )

-        if use_aiter_moe:
+        if _use_aiter:
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
-                # TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
+                # TODO(_use_aiter): FP8 block_quant only supports 'silu' for the time being.
                 assert (
                     activation == "silu"
-                ), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
+                ), f"_use_aiter: FP8 block_quant {activation=} will be supported later, unset SGLANG_USE_AITER"
             return asm_moe(
                 x,
                 layer.w13_weight,
......
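Note: after this commit, `Fp8MoEMethod.apply` effectively dispatches in the order INT4-FP8 → AITER asm kernels → triton fallback. A hedged outline of that priority (the helper below is illustrative only; the real method calls `ck_moe_2stages` / `asm_moe` directly, as the hunks show):

```python
def select_fp8_moe_path(use_hip_int4: bool, use_aiter: bool,
                        block_quant: bool, activation: str) -> str:
    # Mirrors the order of checks in Fp8MoEMethod.apply after this commit.
    if use_hip_int4:
        return "ck_moe_2stages"  # INT4 MoE weights, FP8 compute
    if use_aiter:
        if block_quant and activation != "silu":
            # FP8 block quant via AITER currently supports only 'silu'.
            raise NotImplementedError(activation)
        return "asm_moe"
    return "triton_fused_moe"  # default fused-MoE path
```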
@@ -38,11 +38,10 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

-use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
-
-if _is_hip and use_aiter_moe:
-    from aiter import gemm_a8w8_blockscale
+if _use_aiter:
+    from aiter import gemm_a8w8_blockscale_CK

 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
@@ -141,7 +140,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
         return flashinfer_gemm_w8a8_block_fp8_linear
     elif CUTLASS_BLOCK_FP8_SUPPORTED:
         return cutlass_w8a8_block_fp8_linear_with_fallback
-    elif _is_hip and use_aiter_moe:
+    elif _use_aiter:
         return aiter_w8a8_block_fp8_linear
     elif _ENABLE_JIT_DEEPGEMM:
         return deepgemm_w8a8_block_fp8_linear_with_fallback
@@ -268,12 +267,9 @@ def aiter_w8a8_block_fp8_linear(
     q_input, x_scale = per_token_group_quant_fp8(
         input_2d, block_size[1], column_major_scales=False
     )
-    output = torch.zeros(
-        [q_input.shape[0], weight.shape[0]],
-        dtype=input_2d.dtype,
-        device=q_input.device,
-    )
-    gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
+    output = gemm_a8w8_blockscale_CK(
+        q_input, weight, x_scale, weight_scale, dtype=input.dtype
+    )

     if bias is not None:
         output += bias
......
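Note: `gemm_a8w8_blockscale_CK` returns a freshly allocated output tensor, which is why the `torch.zeros` preallocation is removed. As a sanity reference, the block-scaled GEMM can be emulated by dequantizing and using a plain matmul; the shape conventions below (per-token-group activation scales, per-block weight scales) are assumptions inferred from the surrounding code, not aiter's documented API:

```python
import torch

def blockscale_gemm_reference(
    q_input: torch.Tensor,       # [M, K] quantized activations (e.g. fp8)
    weight: torch.Tensor,        # [N, K] quantized weights
    x_scale: torch.Tensor,       # [M, K // block_k] per-token-group scales
    weight_scale: torch.Tensor,  # [N // block_n, K // block_k] per-block scales
    block_size=(128, 128),       # (block_n, block_k); assumed defaults
    dtype=torch.bfloat16,
) -> torch.Tensor:
    block_n, block_k = block_size
    # Dequantize activations: broadcast each group scale across its K-block.
    a = q_input.to(torch.float32) * x_scale.repeat_interleave(block_k, dim=1)
    # Dequantize weights: broadcast each block scale over a (block_n, block_k) tile.
    w_scale = weight_scale.repeat_interleave(block_n, dim=0)
    w_scale = w_scale.repeat_interleave(block_k, dim=1)
    w = weight.to(torch.float32) * w_scale
    return (a @ w.t()).to(dtype)  # [M, N]
```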
@@ -355,6 +355,15 @@ class ModelRunner:
                 # MLA architecture
                 if is_hopper_with_cuda_12_3():
                     server_args.attention_backend = "fa3"
+                elif _is_hip:
+                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
+                    # TODO: aiter currently only supports KV head counts of 16 or 128
+                    if (
+                        head_num == 128 or head_num == 16
+                    ) and self.spec_algorithm.is_none():
+                        server_args.attention_backend = "aiter"
+                    else:
+                        server_args.attention_backend = "triton"
                 else:
                     server_args.attention_backend = "triton"
                 logger.info(
@@ -363,6 +372,7 @@ class ModelRunner:
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
+                    "aiter",
                     "flashinfer",
                     "fa3",
                     "triton",
......
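Note: the backend auto-selection added to `ModelRunner` reduces to a small decision tree; a sketch (predicate names follow the diff):

```python
def choose_mla_attention_backend(
    is_hopper_cuda_12_3: bool,
    is_hip: bool,
    kv_head_num: int,
    spec_decoding_disabled: bool,
) -> str:
    # Mirrors the new default-backend logic for MLA models.
    if is_hopper_cuda_12_3:
        return "fa3"
    if is_hip:
        # aiter currently supports only KV head counts of 16 or 128, and is
        # not selected when speculative decoding is enabled.
        if kv_head_num in (16, 128) and spec_decoding_disabled:
            return "aiter"
    return "triton"
```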
@@ -105,6 +105,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
@@ -120,6 +121,9 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )

+if _use_aiter:
+    from aiter.rotary_embedding import get_rope
+
 logger = logging.getLogger(__name__)
@@ -697,6 +701,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         self.alt_stream = alt_stream

+        self.attn_mha.kv_b_proj = None
         self.w_kc = None
         self.w_vc = None
@@ -766,6 +771,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif self.attention_backend == "aiter":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
@@ -813,6 +827,9 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ):
+        if self.attn_mha.kv_b_proj is None:
+            self.attn_mha.kv_b_proj = self.kv_b_proj
+
         if hidden_states.shape[0] == 0:
             assert (
                 not self.o_proj.reduce_results
......
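Note: in the new "aiter" branch of the forward-method dispatch, ordinary extend (prefill) batches run full MHA, while decode, target-verify, and draft-extend batches use the weight-absorbed MLA path; schematically:

```python
def aiter_attn_forward_method(mode) -> str:
    # `mode` stands in for a ForwardMode; sketch of the new "aiter" branch.
    if mode.is_extend() and not mode.is_target_verify() and not mode.is_draft_extend():
        return "MHA"  # ordinary prefill/extend: full multi-head attention
    return "MLA"      # decode / target-verify / draft-extend: absorbed MLA
```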
 #!/bin/bash
 set -euo pipefail

 # Default working directory
 WORKDIR="/sglang-checkout/test/srt"

-ENV_ARGS=(
-  -e SGLANG_AMD_CI=1
-  -e SGLANG_IS_IN_CI=1
-  -e SGLANG_AITER_MOE=1
+declare -A ENV_MAP=(
+  [SGLANG_AMD_CI]=1
+  [SGLANG_IS_IN_CI]=1
+  [SGLANG_USE_AITER]=1
 )

-# Parse optional -w/--workdir and -e ENV=VAL flags
+# Parse -w/--workdir and -e ENV=VAL
 while [[ $# -gt 0 ]]; do
   case "$1" in
     -w|--workdir)
@@ -17,7 +16,8 @@ while [[ $# -gt 0 ]]; do
       shift 2
       ;;
     -e)
-      ENV_ARGS+=("-e" "$2")
+      IFS="=" read -r key val <<< "$2"
+      ENV_MAP["$key"]="$val"
       shift 2
       ;;
     --)
@@ -30,6 +30,12 @@ while [[ $# -gt 0 ]]; do
   esac
 done

+# Build final ENV_ARGS
+ENV_ARGS=()
+for key in "${!ENV_MAP[@]}"; do
+  ENV_ARGS+=("-e" "$key=${ENV_MAP[$key]}")
+done
+
 # Run docker exec
 docker exec \
   -w "$WORKDIR" \
......
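Note: because the defaults now live in an associative array, a command-line `-e KEY=VAL` overwrites the corresponding default instead of appending a duplicate `-e` flag to `docker exec`. This is what lets the CI workflow above disable AITER for a single test with `bash scripts/amd_ci_exec.sh -e SGLANG_USE_AITER=0 python3 test_moe_eval_accuracy_large.py`.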
@@ -171,7 +171,7 @@ class TestNightlyGsm8KEval(unittest.TestCase):
             os.environ["HF_HUB_DISABLE_XET"] = (
                 "1" if model in DISABLE_HF_XET_MODELS else "0"
             )
-            os.environ["SGLANG_AITER_MOE"] = (
+            os.environ["SGLANG_USE_AITER"] = (
                 "0" if model in TRITON_MOE_MODELS else "1"
             )
......