Unverified Commit e50109f2 authored by Hubert Lu, committed by GitHub

[AMD] Remove vllm's scaled_fp8_quant and moe_sum when SGLANG_USE_AITER=1 (#7484)

parent 69adc4f8
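
In short, the commit gates the FP8 quantization and MoE reduction kernels on `SGLANG_USE_AITER` instead of always importing them from vLLM on ROCm. Below is a minimal sketch of the dispatch pattern, using helpers that already appear in the diff (`get_bool_env_var`, `is_hip`); it is a simplified illustration, not the exact module-level code, and assumes sglang, aiter, and vllm are importable.

```python
# Simplified sketch of the import gating this commit applies on ROCm.
# The real modules wrap these imports in try/except and platform checks.
from sglang.srt.utils import get_bool_env_var, is_hip

_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _use_aiter:
    # AITER supplies the FP8 quant and moe_sum kernels on AMD GPUs.
    from aiter import moe_sum  # noqa: F401
else:
    # Otherwise the vLLM custom ops remain the fallback.
    from vllm import _custom_ops as vllm_ops  # noqa: F401
```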
@@ -54,14 +54,11 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

-if not _is_npu:
+if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul

     from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

-if _is_hip:
-    from vllm._custom_ops import scaled_fp8_quant
-
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
......
@@ -39,11 +39,20 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
 elif _is_cpu and _is_cpu_amx_available:
     pass
+elif _is_hip:
+    from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+
+    if _use_aiter:
+        try:
+            from aiter import moe_sum
+        except ImportError:
+            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
@@ -1521,11 +1530,7 @@ def fused_experts_impl(
     routed_scaling_factor: Optional[float] = None,
 ):
     padded_size = padding_size
-    if (
-        not (use_fp8_w8a8 or use_int8_w8a8)
-        or block_shape is not None
-        or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
-    ):
+    if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
         padded_size = 0

     # Check constraints.
@@ -1723,6 +1728,17 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 routed_scaling_factor,
             )
+        elif _is_hip:
+            if _use_aiter:
+                moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
+            else:
+                vllm_ops.moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
         else:
             vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
......
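
For context, `moe_sum` (in both the aiter and vLLM variants dispatched above) reduces the per-expert partial outputs into the final hidden states. A reference-only sketch, assuming the usual fused-MoE layout where `intermediate_cache3` is `[num_tokens, top_k, hidden_size]` (an assumption; the real kernels fuse this reduction on the GPU):

```python
import torch


def moe_sum_reference(intermediate_cache3: torch.Tensor, out: torch.Tensor) -> None:
    # Sum the top_k expert outputs of each token into the preallocated output.
    torch.sum(intermediate_cache3, dim=1, out=out)
```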
@@ -20,7 +20,7 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -32,8 +32,9 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()

-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
......
@@ -95,10 +95,9 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _is_hip and (_use_aiter or _use_hip_int4):
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm._custom_ops import scaled_fp8_quant
......
......@@ -27,6 +27,7 @@ from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.utils import (
align,
direct_register_custom_op,
get_bool_env_var,
get_device_core_count,
get_device_name,
is_cpu,
@@ -39,6 +40,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import (
@@ -47,6 +49,22 @@ if _is_cuda:
         sgl_per_token_quant_fp8,
     )

+if _is_hip:
+    if _use_aiter:
+        try:
+            from aiter import (  # v0.1.3
+                dynamic_per_tensor_quant,
+                dynamic_per_token_scaled_quant,
+                static_per_tensor_quant,
+            )
+        except ImportError:
+            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+    else:
+        try:
+            import vllm._C
+        except ImportError:
+            raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")
+
 logger = logging.getLogger(__name__)
@@ -1116,58 +1134,109 @@ def per_token_group_quant_mla_deep_gemm_masked_fp8(
     return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m


-def scaled_fp8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-    num_token_padding: Optional[int] = None,
-    use_per_token_if_dynamic: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize input tensor to FP8 (8-bit floating point) format.
-
-    Args:
-        input (torch.Tensor): Input tensor to be quantized
-        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
-            If None, scales will be computed dynamically.
-        num_token_padding (Optional[int]): If specified, pad the first dimension
-            of the output to at least this value.
-        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
-            determines the quantization granularity:
-            - True: compute scale per token
-            - False: compute single scale per tensor
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - quantized_tensor: The FP8 quantized version of input
-            - scale_tensor: The scaling factors used for quantization
-
-    Raises:
-        AssertionError: If input is not 2D or if static scale's numel != 1
-    """
-    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
-    shape = input.shape
-    if num_token_padding:
-        shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
-
-    if scale is None:
-        # Dynamic scaling
-        if use_per_token_if_dynamic:
-            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
-            sgl_per_token_quant_fp8(input, output, scale)
-        else:
-            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-            sgl_per_tensor_quant_fp8(
-                input, output, scale, is_static=False
-            )  # False for dynamic
-    else:
-        # Static scaling
-        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
-        sgl_per_tensor_quant_fp8(
-            input, output, scale, is_static=True
-        )  # True for static
-
-    return output, scale
+if _is_hip:
+
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        num_token_padding: Optional[int] = None,
+        use_per_token_if_dynamic: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Quantize input tensor to FP8 (8-bit floating point) format.
+
+        Args:
+            input (torch.Tensor): Input tensor to be quantized
+            scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+                If None, scales will be computed dynamically.
+            num_token_padding (Optional[int]): If specified, pad the first dimension
+                of the output to at least this value.
+            use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+                determines the quantization granularity:
+                - True: compute scale per token
+                - False: compute single scale per tensor
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+                - quantized_tensor: The FP8 quantized version of input
+                - scale_tensor: The scaling factors used for quantization
+
+        Raises:
+            AssertionError: If input is not 2D or if static scale's numel != 1
+        """
+        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+        shape = input.shape
+        if num_token_padding:
+            shape = (max(num_token_padding, input.shape[0]), shape[1])
+        output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
+
+        if scale is None:
+            # Dynamic scaling
+            if use_per_token_if_dynamic:
+                scale = torch.empty(
+                    (shape[0], 1), device=input.device, dtype=torch.float32
+                )
+                if _use_aiter:
+                    dynamic_per_token_scaled_quant(output, input, scale)
+                else:
+                    torch.ops._C.dynamic_per_token_scaled_fp8_quant(
+                        output, input.contiguous(), scale, None
+                    )
+            else:
+                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+                if _use_aiter:
+                    dynamic_per_tensor_quant(output, input, scale)
+                else:
+                    torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
+        else:
+            # Static scaling
+            assert (
+                scale.numel() == 1
+            ), f"Expected scalar scale, got numel={scale.numel()}"
+            if _use_aiter:
+                static_per_tensor_quant(output, input, scale)
+            else:
+                torch.ops._C.static_scaled_fp8_quant(output, input, scale)
+
+        return output, scale
+
+else:
+
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        num_token_padding: Optional[int] = None,
+        use_per_token_if_dynamic: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+        shape = input.shape
+        if num_token_padding:
+            shape = (max(num_token_padding, input.shape[0]), shape[1])
+        output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
+
+        if scale is None:
+            # Dynamic scaling
+            if use_per_token_if_dynamic:
+                scale = torch.empty(
+                    (shape[0], 1), device=input.device, dtype=torch.float32
+                )
+                sgl_per_token_quant_fp8(input, output, scale)
+            else:
+                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+                sgl_per_tensor_quant_fp8(
+                    input, output, scale, is_static=False
+                )  # False for dynamic
+        else:
+            # Static scaling
+            assert (
+                scale.numel() == 1
+            ), f"Expected scalar scale, got numel={scale.numel()}"
+            sgl_per_tensor_quant_fp8(
+                input, output, scale, is_static=True
+            )  # True for static
+
+        return output, scale


 fp8_autotune = triton.autotune(
......
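
The docstring in the hunk above fully describes the call contract, which is unchanged by the HIP/non-HIP split; a hypothetical usage sketch (shapes and values are illustrative only):

```python
import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")

# Dynamic per-tensor quantization: a single scale for the whole tensor.
q, s = scaled_fp8_quant(x)  # s.shape == (1,)

# Dynamic per-token quantization: one scale per row.
q, s = scaled_fp8_quant(x, use_per_token_if_dynamic=True)  # s.shape == (16, 1)

# Static quantization with a precomputed scalar scale.
static_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
q, s = scaled_fp8_quant(x, scale=static_scale)
```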
@@ -37,7 +37,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _use_aiter:
     from aiter import ActivationType
     from aiter.fused_moe import fused_moe
-    from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
......
@@ -12,7 +12,7 @@ import torch
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu

 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -21,8 +21,9 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()

-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm._custom_ops import scaled_fp8_quant
......
@@ -3,8 +3,13 @@
 import pytest
 import torch

-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
+fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn


 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     def quantize_ref_per_tensor(tensor, inv_scale):
         # The reference implementation that fully aligns to
         # the kernel being tested.
-        finfo = torch.finfo(torch.float8_e4m3fn)
+        finfo = torch.finfo(fp8_dtype)
         scale = inv_scale.reciprocal()
         qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
-        qweight = qweight.to(torch.float8_e4m3fn)
+        qweight = qweight.to(fp8_dtype)
         return qweight

     def dequantize_per_tensor(tensor, inv_scale, dtype):
@@ -48,19 +53,19 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     )


-if is_cuda:
+if _is_cuda or _is_hip:

     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
     def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
         def quantize_ref_per_token(tensor, inv_scale):
             # The reference implementation that fully aligns to
             # the kernel being tested.
-            finfo = torch.finfo(torch.float8_e4m3fn)
+            finfo = torch.finfo(fp8_dtype)
             scale = inv_scale.reciprocal()
             qweight = (tensor.to(torch.float32) * scale).clamp(
                 min=finfo.min, max=finfo.max
             )
-            qweight = qweight.to(torch.float8_e4m3fn)
+            qweight = qweight.to(fp8_dtype)
             return qweight

         def dequantize_per_token(tensor, inv_scale, dtype):
......
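
The test now derives `fp8_dtype` from `is_fp8_fnuz()` because ROCm devices that use the `fnuz` FP8 variant have a different representable range than the OCP `e4m3fn` format; a quick check of the two dtypes:

```python
import torch

# Maximum representable values of the two FP8 e4m3 variants.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
```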