Unverified Commit b70957fc authored by JieXin Liang's avatar JieXin Liang Committed by GitHub
Browse files

[refactor] slightly tidy fp8 module (#5993)

parent e444c13f
......@@ -12,7 +12,7 @@ from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
)
logger = logging.getLogger(__name__)
......@@ -654,10 +654,7 @@ def grouped_gemm_triton(
if block_shape is not None:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
if _is_cuda:
a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
else:
a, scale_a = per_token_group_quant_fp8(a, block_k)
a, scale_a = per_token_group_quant_fp8(a, block_k)
assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
......
......@@ -10,16 +10,14 @@ import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
all_close_1d,
is_cuda,
is_fp8_fnuz,
per_tensor_dequantize,
replace_parameter,
)
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import is_cuda, set_weight_attrs
_is_cuda = is_cuda()
......
......@@ -15,11 +15,12 @@ from sglang.srt.layers.parameter import (
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
__all__ = ["CompressedTensorsW8A8Fp8"]
......
......@@ -42,6 +42,8 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import (
fp8_dtype,
is_fp8_fnuz,
per_token_group_quant_fp8,
scaled_fp8_quant,
)
......@@ -71,6 +73,11 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
if _is_hip:
from aiter import ActivationType, QuantType
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
......@@ -306,25 +313,21 @@ class Fp8LinearMethod(LinearMethodBase):
# Block quant doesn't need to process weights after loading
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_hip:
if _is_fp8_fnuz:
# activation_scheme: dynamic
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=None,
)
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale_inv = torch.nn.Parameter(
weight_scale, requires_grad=False
)
layer.input_scale = None
else:
layer.weight = torch.nn.Parameter(
layer.weight.data, requires_grad=False
)
layer.weight_scale_inv = torch.nn.Parameter(
layer.weight_scale_inv.data, requires_grad=False
)
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale_inv = torch.nn.Parameter(
weight_scale, requires_grad=False
)
return
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
......@@ -368,7 +371,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight
weight_scale = layer.weight_scale
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_hip:
if _is_fp8_fnuz:
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=weight_scale,
......@@ -482,11 +485,7 @@ class Fp8MoEMethod:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = (
torch.uint32
if get_bool_env_var("SGLANG_INT4_WEIGHT")
else torch.float8_e4m3fn
)
params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
......@@ -511,7 +510,7 @@ class Fp8MoEMethod:
)
# WEIGHTS
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
if _is_hip and use_hip_int4:
# INT4 MoE weight - INT32 packed
w13_weight = torch.nn.Parameter(
torch.empty(
......@@ -583,9 +582,7 @@ class Fp8MoEMethod:
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
if (
_is_hip
): # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
if _is_hip: # and use_aiter_moe: TODO: add check back after triton kernel
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
w13_weight_scale1 = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
......@@ -612,7 +609,7 @@ class Fp8MoEMethod:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
if _is_hip and use_hip_int4:
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
......@@ -644,14 +641,14 @@ class Fp8MoEMethod:
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
if _is_hip and use_hip_int4:
self.process_weights_hip_int4(layer)
return
# Block quant doesn't need to process weights after loading
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_hip:
if _is_fp8_fnuz:
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
......@@ -675,20 +672,19 @@ class Fp8MoEMethod:
)
layer.w2_input_scale = None
if get_bool_env_var("SGLANG_AITER_MOE"):
# Pre-shuffle weights
layer.w13_weight.data = shuffle_weight(
layer.w13_weight.contiguous(), (16, 16)
)
layer.w2_weight.data = shuffle_weight(
layer.w2_weight.contiguous(), (16, 16)
)
if _is_hip and use_aiter_moe:
# Pre-shuffle weights
layer.w13_weight.data = shuffle_weight(
layer.w13_weight.contiguous(), (16, 16)
)
layer.w2_weight.data = shuffle_weight(
layer.w2_weight.contiguous(), (16, 16)
)
return
# If checkpoint is fp16 or bfloat16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
# If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW)
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
......@@ -742,7 +738,7 @@ class Fp8MoEMethod:
)
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_hip:
if _is_fp8_fnuz:
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
......@@ -798,7 +794,7 @@ class Fp8MoEMethod:
return
def process_weights_hip_int4(self, layer: Module):
# TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
# TODO: and use_aiter_moe: add after triton kernel added
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
# Weight Permutation
layer.w13_weight = torch.nn.Parameter(
......@@ -845,7 +841,7 @@ class Fp8MoEMethod:
padding_size, # Avoid circular import
)
if get_bool_env_var("SGLANG_AITER_MOE"):
if use_aiter_moe:
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,
......@@ -856,7 +852,7 @@ class Fp8MoEMethod:
requires_grad=False,
)
torch.cuda.empty_cache()
# ROCm (SGLANG_AITER_MOE): using column-wise scaling
# ROCm (use_aiter_moe): using column-wise scaling
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
elif get_bool_env_var("SGLANG_MOE_PADDING"):
......@@ -908,59 +904,16 @@ class Fp8MoEMethod:
)
if _is_hip:
if get_bool_env_var("SGLANG_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
QuantType.per_Token,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
if get_bool_env_var("SGLANG_AITER_MOE"):
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
QuantType.per_Token,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
ret = self.maybe_apply_hip_fused_experts(
layer,
x,
topk_weights,
topk_ids,
activation,
no_combine,
)
if ret is not None:
return ret
# Expert fusion with FP8 quantization
return fused_experts(
......@@ -987,6 +940,68 @@ class Fp8MoEMethod:
no_combine=no_combine,
)
def maybe_apply_hip_fused_experts(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
no_combine: bool = False,
) -> Optional[torch.Tensor]:
if use_hip_int4:
# TODO: add triton kernel and add check use_aiter_moe
assert not no_combine, f"{no_combine=} is not supported."
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
QuantType.per_Token,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
),
)
if use_aiter_moe:
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
QuantType.per_Token,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
return None
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
......
......@@ -16,6 +16,7 @@ import functools
import json
import logging
import os
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple
import torch
......@@ -34,12 +35,6 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if _is_hip:
fp8_max = 224.0
else:
fp8_max = torch.finfo(_fp8_type).max
fp8_min = -fp8_max
if _is_cuda:
from sgl_kernel import (
......@@ -54,6 +49,24 @@ if _is_cuda:
logger = logging.getLogger(__name__)
@lru_cache()
def is_fp8_fnuz() -> bool:
if _is_hip:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
return False
if is_fp8_fnuz():
fp8_dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
else:
fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max
fp8_min = -fp8_max
if supports_custom_op():
def deep_gemm_fp8_fp8_bf16_nt(
......@@ -198,7 +211,7 @@ def per_token_group_quant_fp8(
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
M = x.numel() // group_size
N = group_size
if column_major_scales:
......@@ -272,7 +285,7 @@ def sglang_per_token_group_quant_fp8(
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
if column_major_scales:
if scale_tma_aligned:
# aligned to 4 * sizeof(float)
......@@ -302,7 +315,7 @@ def sglang_per_token_group_quant_fp8(
def sglang_per_token_quant_fp8(
x: torch.Tensor,
dtype: torch.dtype = _fp8_type,
dtype: torch.dtype = fp8_dtype,
):
assert x.is_contiguous(), "`x` is not contiguous"
......@@ -384,7 +397,7 @@ def static_quant_fp8(
assert x.is_contiguous(), "`x` is not contiguous"
assert x_s.numel() == 1, "only supports per-tensor scale"
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
M = x.numel() // x.shape[-1]
N = x.shape[-1]
if repeat_scale:
......@@ -704,6 +717,28 @@ def get_w8a8_block_fp8_configs(
return None
def select_w8a8_block_fp8_matmul_kernel(M, N, META):
return _w8a8_block_fp8_matmul
if _is_hip:
def use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(
N, META["BLOCK_SIZE_N"]
)
num_workgroups <= get_device_core_count()
def select_w8a8_block_fp8_matmul_kernel(M, N, META):
if use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
return _w8a8_block_fp8_matmul_unrolledx4
else:
return _w8a8_block_fp8_matmul
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
......@@ -744,35 +779,6 @@ def w8a8_block_fp8_matmul(
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
# deepgemm only support bf16
if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
if supports_custom_op():
......@@ -780,11 +786,30 @@ def w8a8_block_fp8_matmul(
else:
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
else:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
kernel[grid](
A,
......@@ -879,7 +904,7 @@ def per_tensor_quant_mla_fp8(
and x_s_out.device == x.device
)
x_q = x.new_empty(x.size(), dtype=_fp8_type)
x_q = x.new_empty(x.size(), dtype=fp8_dtype)
num_head, num_seq, head_size = x.shape
BLOCK_SIZE = triton.next_power_of_2(head_size)
......@@ -961,11 +986,11 @@ def _per_token_group_quant_mla_deep_gemm_masked_fp8(
tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
def per_tensor_quant_mla_deep_gemm_masked_fp8(
def per_token_group_quant_mla_deep_gemm_masked_fp8(
x: torch.Tensor,
group_size: int = 128,
eps: float = 1e-12,
dtype: torch.dtype = torch.float8_e4m3fn,
dtype: torch.dtype = fp8_dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This function quantizes input values to float8 values with per-token-group-quantization
......@@ -973,12 +998,6 @@ def per_tensor_quant_mla_deep_gemm_masked_fp8(
"""
assert x.dim() == 3, "`x` is not a 3d-tensor"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
b, m, k = x.shape
aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel
num_tiles_k = k // group_size
......@@ -1043,10 +1062,9 @@ def scaled_fp8_quant(
"""
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
shape = input.shape
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
if scale is None:
# Dynamic scaling
......
......@@ -14,6 +14,9 @@ except ImportError:
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import (
fp8_dtype,
fp8_max,
is_fp8_fnuz,
per_token_group_quant_fp8,
scaled_fp8_quant,
sglang_per_token_quant_fp8,
......@@ -30,8 +33,11 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
if _is_hip and use_aiter_moe:
from aiter import gemm_a8w8_blockscale
if _is_cuda:
......@@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None
_TORCH_VERSION = torch.__version__.split("+")[0]
try:
_TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
except ValueError:
_TORCH_VERSION_TUPLE = (0, 0, 0)
# The condition to determine if it is on a platform that supports
# torch._scaled_mm rowwise feature.
# The condition is determined once as the operations
# are time consuming.
USE_ROWWISE_TORCH_SCALED_MM = (
_is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
)
def use_rowwise_torch_scaled_mm():
_TORCH_VERSION = torch.__version__.split("+")[0]
try:
_TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
except ValueError:
_TORCH_VERSION_TUPLE = (0, 0, 0)
if _is_hip:
# The condition to determine if it is on a platform that supports
# torch._scaled_mm rowwise feature.
# The condition is determined once as the operations
# are time consuming.
return get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
return False
USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()
def cutlass_fp8_supported():
......@@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear(
output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
)
elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
elif _is_hip and use_aiter_moe:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
......@@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear(
def input_to_float8(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
x: torch.Tensor, dtype: torch.dtype = fp8_dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to float8 values with tensor-wise quantization."""
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
scale = fp8_max / amax
x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
if _is_fp8_fnuz:
dtype = fp8_dtype
fp_max = fp8_max
else:
finfo = torch.finfo(dtype)
fp_max = finfo.max
scale = fp_max / amax
x_scl_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
......
......@@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_hip
_is_hip = is_hip()
logger = logging.getLogger(__name__)
......@@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase):
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
)
@classmethod
def is_fp8_fnuz(cls) -> bool:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
......@@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
if is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
......@@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
if is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
......
......@@ -14,11 +14,6 @@ if not _is_cuda:
from vllm._custom_ops import scaled_fp8_quant
def is_fp8_fnuz() -> bool:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
def is_layer_skipped(
prefix: str,
ignored_layers: List[str],
......
......@@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_kernel import (
fp8_dtype,
is_fp8_fnuz,
per_token_group_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
cutlass_fp8_supported,
input_to_float8,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import is_hip, set_weight_attrs
from sglang.srt.utils import set_weight_attrs
_is_hip = is_hip()
_is_fp8_fnuz = is_fp8_fnuz()
class W8A8Fp8Config(QuantizationConfig):
......@@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
if self.quantization_config.is_checkpoint_fp8_serialized:
weight_scale = layer.weight_scale.detach()
# If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
if _is_hip:
if _is_fp8_fnuz:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale
)
......@@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
layer.weight, layer.weight.shape[-1]
)
weight_scale = weight_scale.t().contiguous()
if _is_hip:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale
)
else:
# if cutlass not supported, we fall back to use torch._scaled_mm
# which requires per tensor quantization on weight
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
# Update the layer with the new values.
......@@ -227,7 +226,6 @@ class W8A8FP8MoEMethod:
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
......
......@@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.int8_utils import (
block_dequant as int8_block_dequant,
)
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
_is_hip = is_hip()
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import awq_dequantize
else:
from vllm._custom_ops import awq_dequantize
from sglang.srt.utils import BumpAllocator, add_prefix
logger = logging.getLogger(__name__)
......
......@@ -59,8 +59,8 @@ from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_deep_gemm_masked_fp8,
per_tensor_quant_mla_fp8,
per_token_group_quant_mla_deep_gemm_masked_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant,
......@@ -738,9 +738,7 @@ class DeepseekV2AttentionMLA(nn.Module):
if self.use_deep_gemm_bmm:
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
per_tensor_quant_mla_deep_gemm_masked_fp8(
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
)
per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
)
q_nope_out = q_nope.new_empty(
(self.num_local_heads, aligned_m, self.kv_lora_rank)
......@@ -785,8 +783,8 @@ class DeepseekV2AttentionMLA(nn.Module):
if self.use_deep_gemm_bmm:
attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
per_tensor_quant_mla_deep_gemm_masked_fp8(
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
per_token_group_quant_mla_deep_gemm_masked_fp8(
attn_output.transpose(0, 1)
)
)
attn_bmm_output = attn_output.new_empty(
......
......@@ -7,9 +7,9 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_deep_gemm_masked_fp8,
per_tensor_quant_mla_fp8,
per_token_group_quant_fp8,
per_token_group_quant_mla_deep_gemm_masked_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,
)
......@@ -236,7 +236,7 @@ class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase):
with torch.inference_mode():
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8(
out, scale, _, _, _ = per_token_group_quant_mla_deep_gemm_masked_fp8(
x, group_size
)
out = out[:, :num_tokens, :]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment