Unverified commit e50109f2, authored by Hubert Lu and committed by GitHub

[AMD] Remove vllm's scaled_fp8_quant and moe_sum when SGLANG_USE_AITER=1 (#7484)

parent 69adc4f8
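In short: on ROCm, setting SGLANG_USE_AITER=1 now dispatches scaled_fp8_quant and moe_sum to aiter kernels, and vllm's custom ops are only imported as the fallback path when the flag is off. A minimal sketch of this env-var-gated dispatch pattern (simplified for illustration; use_aiter_on_rocm is a hypothetical helper, not the actual sglang layout):

import os

def use_aiter_on_rocm(is_hip_build: bool) -> bool:
    # Hypothetical stand-in for sglang.srt.utils.get_bool_env_var(...) and is_hip().
    flag = os.getenv("SGLANG_USE_AITER", "false").lower() in ("1", "true")
    return flag and is_hip_build

if use_aiter_on_rocm(is_hip_build=True):
    from aiter import moe_sum  # aiter kernel, no vllm dependency
else:
    from vllm import _custom_ops as vllm_ops  # legacy fallback path
    moe_sum = vllm_ops.moe_sum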
@@ -54,14 +54,11 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

-if not _is_npu:
+if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul

 from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

-if _is_hip:
-    from vllm._custom_ops import scaled_fp8_quant
-
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
...
@@ -39,11 +39,20 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
 elif _is_cpu and _is_cpu_amx_available:
     pass
+elif _is_hip:
+    from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+
+    if _use_aiter:
+        try:
+            from aiter import moe_sum
+        except ImportError:
+            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
@@ -1521,11 +1530,7 @@ def fused_experts_impl(
     routed_scaling_factor: Optional[float] = None,
 ):
     padded_size = padding_size
-    if (
-        not (use_fp8_w8a8 or use_int8_w8a8)
-        or block_shape is not None
-        or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
-    ):
+    if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
         padded_size = 0

     # Check constraints.
@@ -1723,6 +1728,17 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 routed_scaling_factor,
             )
+        elif _is_hip:
+            if _use_aiter:
+                moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
+            else:
+                vllm_ops.moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
         else:
             vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
...
@@ -20,7 +20,7 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -32,8 +32,9 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()

-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
...
@@ -95,10 +95,9 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _is_hip and (_use_aiter or _use_hip_int4):
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm._custom_ops import scaled_fp8_quant
...
@@ -27,6 +27,7 @@ from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     align,
     direct_register_custom_op,
+    get_bool_env_var,
     get_device_core_count,
     get_device_name,
     is_cpu,
@@ -39,6 +40,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import (
@@ -47,6 +49,22 @@ if _is_cuda:
         sgl_per_token_quant_fp8,
     )

+if _is_hip:
+    if _use_aiter:
+        try:
+            from aiter import (  # v0.1.3
+                dynamic_per_tensor_quant,
+                dynamic_per_token_scaled_quant,
+                static_per_tensor_quant,
+            )
+        except ImportError:
+            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+    else:
+        try:
+            import vllm._C
+        except ImportError:
+            raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")
+
 logger = logging.getLogger(__name__)
@@ -1116,16 +1134,10 @@ def per_token_group_quant_mla_deep_gemm_masked_fp8(
     return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m


-def scaled_fp8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-    num_token_padding: Optional[int] = None,
-    use_per_token_if_dynamic: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize input tensor to FP8 (8-bit floating point) format.
+"""
+Quantize input tensor to FP8 (8-bit floating point) format.

     Args:
         input (torch.Tensor): Input tensor to be quantized
         scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
             If None, scales will be computed dynamically.
@@ -1136,14 +1148,22 @@ scaled_fp8_quant(
             - True: compute scale per token
             - False: compute single scale per tensor

     Returns:
         Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
             - quantized_tensor: The FP8 quantized version of input
             - scale_tensor: The scaling factors used for quantization

     Raises:
         AssertionError: If input is not 2D or if static scale's numel != 1
     """
+
+if _is_hip:
+
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        num_token_padding: Optional[int] = None,
+        use_per_token_if_dynamic: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
         shape = input.shape
         if num_token_padding:
@@ -1153,7 +1173,54 @@ def scaled_fp8_quant(
         if scale is None:
             # Dynamic scaling
             if use_per_token_if_dynamic:
-                scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+                scale = torch.empty(
+                    (shape[0], 1), device=input.device, dtype=torch.float32
+                )
+                if _use_aiter:
+                    dynamic_per_token_scaled_quant(output, input, scale)
+                else:
+                    torch.ops._C.dynamic_per_token_scaled_fp8_quant(
+                        output, input.contiguous(), scale, None
+                    )
+            else:
+                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+                if _use_aiter:
+                    dynamic_per_tensor_quant(output, input, scale)
+                else:
+                    torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
+        else:
+            # Static scaling
+            assert (
+                scale.numel() == 1
+            ), f"Expected scalar scale, got numel={scale.numel()}"
+            if _use_aiter:
+                static_per_tensor_quant(output, input, scale)
+            else:
+                torch.ops._C.static_scaled_fp8_quant(output, input, scale)
+        return output, scale
+
+else:
+
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        num_token_padding: Optional[int] = None,
+        use_per_token_if_dynamic: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+        shape = input.shape
+        if num_token_padding:
+            shape = (max(num_token_padding, input.shape[0]), shape[1])
+        output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
+
+        if scale is None:
+            # Dynamic scaling
+            if use_per_token_if_dynamic:
+                scale = torch.empty(
+                    (shape[0], 1), device=input.device, dtype=torch.float32
+                )
                 sgl_per_token_quant_fp8(input, output, scale)
             else:
                 scale = torch.zeros(1, device=input.device, dtype=torch.float32)
@@ -1162,7 +1229,9 @@ def scaled_fp8_quant(
                 )  # False for dynamic
         else:
             # Static scaling
-            assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
+            assert (
+                scale.numel() == 1
+            ), f"Expected scalar scale, got numel={scale.numel()}"
             sgl_per_tensor_quant_fp8(
                 input, output, scale, is_static=True
             )  # True for static
...
@@ -37,7 +37,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _use_aiter:
     from aiter import ActivationType
     from aiter.fused_moe import fused_moe
-    from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
...
@@ -12,7 +12,7 @@ import torch
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu

 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -21,8 +21,9 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()

-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm._custom_ops import scaled_fp8_quant
...
@@ -3,8 +3,13 @@
 import pytest
 import torch

-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
+fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn


 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     def quantize_ref_per_tensor(tensor, inv_scale):
         # The reference implementation that fully aligns to
         # the kernel being tested.
-        finfo = torch.finfo(torch.float8_e4m3fn)
+        finfo = torch.finfo(fp8_dtype)
         scale = inv_scale.reciprocal()
         qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
-        qweight = qweight.to(torch.float8_e4m3fn)
+        qweight = qweight.to(fp8_dtype)
         return qweight

     def dequantize_per_tensor(tensor, inv_scale, dtype):
@@ -48,19 +53,19 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     )


-if is_cuda:
+if _is_cuda or _is_hip:

     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
     def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
         def quantize_ref_per_token(tensor, inv_scale):
             # The reference implementation that fully aligns to
             # the kernel being tested.
-            finfo = torch.finfo(torch.float8_e4m3fn)
+            finfo = torch.finfo(fp8_dtype)
             scale = inv_scale.reciprocal()
             qweight = (tensor.to(torch.float32) * scale).clamp(
                 min=finfo.min, max=finfo.max
             )
-            qweight = qweight.to(torch.float8_e4m3fn)
+            qweight = qweight.to(fp8_dtype)
             return qweight

         def dequantize_per_token(tensor, inv_scale, dtype):
...
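For reference, a minimal usage sketch of the scaled_fp8_quant entry point these tests exercise (behavior as documented in the fp8_kernel docstring above; assumes a CUDA or ROCm build of torch and sglang):

import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

x = torch.randn(16, 128, dtype=torch.float16, device="cuda")

# Dynamic per-tensor scaling: a single scale is computed by the kernel.
q, s = scaled_fp8_quant(x)

# Dynamic per-token scaling: one scale per row of the 2D input.
q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# Static scaling: a precomputed scalar scale (numel must be 1).
static_scale = torch.tensor([0.5], device="cuda", dtype=torch.float32)
q_stat, s_stat = scaled_fp8_quant(x, scale=static_scale)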