"tests/vscode:/vscode.git/clone" did not exist on "1506560e1ecc6a2c56b51e23766632a9326aeeca"
Unverified Commit b70957fc authored by JieXin Liang's avatar JieXin Liang Committed by GitHub
Browse files

[refactor] slightly tidy fp8 module (#5993)

parent e444c13f
...@@ -12,7 +12,7 @@ from sglang.srt.utils import is_cuda ...@@ -12,7 +12,7 @@ from sglang.srt.utils import is_cuda
_is_cuda = is_cuda() _is_cuda = is_cuda()
if _is_cuda: if _is_cuda:
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8, sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -654,10 +654,7 @@ def grouped_gemm_triton( ...@@ -654,10 +654,7 @@ def grouped_gemm_triton(
if block_shape is not None: if block_shape is not None:
assert len(block_shape) == 2 assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1] block_n, block_k = block_shape[0], block_shape[1]
if _is_cuda: a, scale_a = per_token_group_quant_fp8(a, block_k)
a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
else:
a, scale_a = per_token_group_quant_fp8(a, block_k)
assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1] assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2] assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
......
...@@ -10,16 +10,14 @@ import torch ...@@ -10,16 +10,14 @@ import torch
from compressed_tensors import CompressionFormat from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import ( from sglang.srt.layers.quantization.utils import (
all_close_1d, all_close_1d,
is_cuda,
is_fp8_fnuz,
per_tensor_dequantize, per_tensor_dequantize,
replace_parameter, replace_parameter,
) )
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import is_cuda, set_weight_attrs
_is_cuda = is_cuda() _is_cuda = is_cuda()
......
...@@ -15,11 +15,12 @@ from sglang.srt.layers.parameter import ( ...@@ -15,11 +15,12 @@ from sglang.srt.layers.parameter import (
from sglang.srt.layers.quantization.compressed_tensors.schemes import ( from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_utils import ( from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear, apply_fp8_linear,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
) )
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale from sglang.srt.layers.quantization.utils import requantize_with_max_scale
__all__ = ["CompressedTensorsW8A8Fp8"] __all__ = ["CompressedTensorsW8A8Fp8"]
......
...@@ -42,6 +42,8 @@ from sglang.srt.layers.quantization.base_config import ( ...@@ -42,6 +42,8 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
fp8_dtype,
is_fp8_fnuz,
per_token_group_quant_fp8, per_token_group_quant_fp8,
scaled_fp8_quant, scaled_fp8_quant,
) )
...@@ -71,6 +73,11 @@ from sglang.srt.utils import ( ...@@ -71,6 +73,11 @@ from sglang.srt.utils import (
_is_hip = is_hip() _is_hip = is_hip()
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
if _is_hip: if _is_hip:
from aiter import ActivationType, QuantType from aiter import ActivationType, QuantType
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
...@@ -306,25 +313,21 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -306,25 +313,21 @@ class Fp8LinearMethod(LinearMethodBase):
# Block quant doesn't need to process weights after loading # Block quant doesn't need to process weights after loading
if self.block_quant: if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz # If ROCm, normalize the weights and scales to e4m3fnuz
if _is_hip: if _is_fp8_fnuz:
# activation_scheme: dynamic # activation_scheme: dynamic
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight, weight=layer.weight,
weight_scale=layer.weight_scale_inv, weight_scale=layer.weight_scale_inv,
input_scale=None, input_scale=None,
) )
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale_inv = torch.nn.Parameter(
weight_scale, requires_grad=False
)
layer.input_scale = None layer.input_scale = None
else: else:
layer.weight = torch.nn.Parameter( weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
layer.weight.data, requires_grad=False layer.weight = torch.nn.Parameter(weight, requires_grad=False)
) layer.weight_scale_inv = torch.nn.Parameter(
layer.weight_scale_inv = torch.nn.Parameter( weight_scale, requires_grad=False
layer.weight_scale_inv.data, requires_grad=False )
)
return return
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
...@@ -368,7 +371,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -368,7 +371,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight weight = layer.weight
weight_scale = layer.weight_scale weight_scale = layer.weight_scale
# If ROCm, normalize the weights and scales to e4m3fnuz # If ROCm, normalize the weights and scales to e4m3fnuz
if _is_hip: if _is_fp8_fnuz:
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight=weight,
weight_scale=weight_scale, weight_scale=weight_scale,
...@@ -482,11 +485,7 @@ class Fp8MoEMethod: ...@@ -482,11 +485,7 @@ class Fp8MoEMethod:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
if self.quant_config.is_checkpoint_fp8_serialized: if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = ( params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
torch.uint32
if get_bool_env_var("SGLANG_INT4_WEIGHT")
else torch.float8_e4m3fn
)
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if self.block_quant: if self.block_quant:
block_n, block_k = ( block_n, block_k = (
...@@ -511,7 +510,7 @@ class Fp8MoEMethod: ...@@ -511,7 +510,7 @@ class Fp8MoEMethod:
) )
# WEIGHTS # WEIGHTS
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"): if _is_hip and use_hip_int4:
# INT4 MoE weight - INT32 packed # INT4 MoE weight - INT32 packed
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
...@@ -583,9 +582,7 @@ class Fp8MoEMethod: ...@@ -583,9 +582,7 @@ class Fp8MoEMethod:
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale)
if ( if _is_hip: # and use_aiter_moe: TODO: add check back after triton kernel
_is_hip
): # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
w13_weight_scale1 = torch.nn.Parameter( w13_weight_scale1 = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32), torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
...@@ -612,7 +609,7 @@ class Fp8MoEMethod: ...@@ -612,7 +609,7 @@ class Fp8MoEMethod:
set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs)
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"): if _is_hip and use_hip_int4:
extra_weight_attrs.update( extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
) )
...@@ -644,14 +641,14 @@ class Fp8MoEMethod: ...@@ -644,14 +641,14 @@ class Fp8MoEMethod:
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"): if _is_hip and use_hip_int4:
self.process_weights_hip_int4(layer) self.process_weights_hip_int4(layer)
return return
# Block quant doesn't need to process weights after loading # Block quant doesn't need to process weights after loading
if self.block_quant: if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz # If ROCm, normalize the weights and scales to e4m3fnuz
if _is_hip: if _is_fp8_fnuz:
# activation_scheme: dynamic # activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight, weight=layer.w13_weight,
...@@ -675,20 +672,19 @@ class Fp8MoEMethod: ...@@ -675,20 +672,19 @@ class Fp8MoEMethod:
) )
layer.w2_input_scale = None layer.w2_input_scale = None
if get_bool_env_var("SGLANG_AITER_MOE"): if _is_hip and use_aiter_moe:
# Pre-shuffle weights # Pre-shuffle weights
layer.w13_weight.data = shuffle_weight( layer.w13_weight.data = shuffle_weight(
layer.w13_weight.contiguous(), (16, 16) layer.w13_weight.contiguous(), (16, 16)
) )
layer.w2_weight.data = shuffle_weight( layer.w2_weight.data = shuffle_weight(
layer.w2_weight.contiguous(), (16, 16) layer.w2_weight.contiguous(), (16, 16)
) )
return return
# If checkpoint is fp16 or bfloat16, quantize in place. # If checkpoint is fp16 or bfloat16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized: if not self.quant_config.is_checkpoint_fp8_serialized:
# If ROCm, use float8_e4m3fnuz instead (MI300x HW) # If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW)
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
...@@ -742,7 +738,7 @@ class Fp8MoEMethod: ...@@ -742,7 +738,7 @@ class Fp8MoEMethod:
) )
# If ROCm, normalize the weights and scales to e4m3fnuz # If ROCm, normalize the weights and scales to e4m3fnuz
if _is_hip: if _is_fp8_fnuz:
# Normalize the weights and scales # Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = ( w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
...@@ -798,7 +794,7 @@ class Fp8MoEMethod: ...@@ -798,7 +794,7 @@ class Fp8MoEMethod:
return return
def process_weights_hip_int4(self, layer: Module): def process_weights_hip_int4(self, layer: Module):
# TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added # TODO: and use_aiter_moe: add after triton kernel added
# INT4-FP8 (INT4 MoE Weight, FP8 Compute) # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
# Weight Permutation # Weight Permutation
layer.w13_weight = torch.nn.Parameter( layer.w13_weight = torch.nn.Parameter(
...@@ -845,7 +841,7 @@ class Fp8MoEMethod: ...@@ -845,7 +841,7 @@ class Fp8MoEMethod:
padding_size, # Avoid circular import padding_size, # Avoid circular import
) )
if get_bool_env_var("SGLANG_AITER_MOE"): if use_aiter_moe:
layer.w13_weight = torch.nn.Parameter( layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)), shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False, requires_grad=False,
...@@ -856,7 +852,7 @@ class Fp8MoEMethod: ...@@ -856,7 +852,7 @@ class Fp8MoEMethod:
requires_grad=False, requires_grad=False,
) )
torch.cuda.empty_cache() torch.cuda.empty_cache()
# ROCm (SGLANG_AITER_MOE): using column-wise scaling # ROCm (use_aiter_moe): using column-wise scaling
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1) layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1) layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
elif get_bool_env_var("SGLANG_MOE_PADDING"): elif get_bool_env_var("SGLANG_MOE_PADDING"):
...@@ -908,59 +904,16 @@ class Fp8MoEMethod: ...@@ -908,59 +904,16 @@ class Fp8MoEMethod:
) )
if _is_hip: if _is_hip:
if get_bool_env_var("SGLANG_INT4_WEIGHT"): ret = self.maybe_apply_hip_fused_experts(
# TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE") layer,
assert not no_combine, f"{no_combine=} is not supported." x,
return ck_moe_2stages( topk_weights,
x, topk_ids,
layer.w13_weight, activation,
layer.w2_weight, no_combine,
topk_weights, )
topk_ids, if ret is not None:
QuantType.per_Token, return ret
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
if get_bool_env_var("SGLANG_AITER_MOE"):
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
QuantType.per_Token,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
# Expert fusion with FP8 quantization # Expert fusion with FP8 quantization
return fused_experts( return fused_experts(
...@@ -987,6 +940,68 @@ class Fp8MoEMethod: ...@@ -987,6 +940,68 @@ class Fp8MoEMethod:
no_combine=no_combine, no_combine=no_combine,
) )
def maybe_apply_hip_fused_experts(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
no_combine: bool = False,
) -> Optional[torch.Tensor]:
if use_hip_int4:
# TODO: add triton kernel and add check use_aiter_moe
assert not no_combine, f"{no_combine=} is not supported."
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
QuantType.per_Token,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
),
)
if use_aiter_moe:
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
QuantType.per_Token,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
return None
class Fp8KVCacheMethod(BaseKVCacheMethod): class Fp8KVCacheMethod(BaseKVCacheMethod):
""" """
......
...@@ -16,6 +16,7 @@ import functools ...@@ -16,6 +16,7 @@ import functools
import json import json
import logging import logging
import os import os
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
...@@ -34,12 +35,6 @@ from sglang.srt.utils import ( ...@@ -34,12 +35,6 @@ from sglang.srt.utils import (
_is_hip = is_hip() _is_hip = is_hip()
_is_cuda = is_cuda() _is_cuda = is_cuda()
_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if _is_hip:
fp8_max = 224.0
else:
fp8_max = torch.finfo(_fp8_type).max
fp8_min = -fp8_max
if _is_cuda: if _is_cuda:
from sgl_kernel import ( from sgl_kernel import (
...@@ -54,6 +49,24 @@ if _is_cuda: ...@@ -54,6 +49,24 @@ if _is_cuda:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@lru_cache()
def is_fp8_fnuz() -> bool:
if _is_hip:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
return False
if is_fp8_fnuz():
fp8_dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
else:
fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max
fp8_min = -fp8_max
if supports_custom_op(): if supports_custom_op():
def deep_gemm_fp8_fp8_bf16_nt( def deep_gemm_fp8_fp8_bf16_nt(
...@@ -198,7 +211,7 @@ def per_token_group_quant_fp8( ...@@ -198,7 +211,7 @@ def per_token_group_quant_fp8(
), "the last dimension of `x` cannot be divisible by `group_size`" ), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous" assert x.is_contiguous(), "`x` is not contiguous"
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type) x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
M = x.numel() // group_size M = x.numel() // group_size
N = group_size N = group_size
if column_major_scales: if column_major_scales:
...@@ -272,7 +285,7 @@ def sglang_per_token_group_quant_fp8( ...@@ -272,7 +285,7 @@ def sglang_per_token_group_quant_fp8(
), "the last dimension of `x` cannot be divisible by `group_size`" ), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous" assert x.is_contiguous(), "`x` is not contiguous"
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type) x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
if column_major_scales: if column_major_scales:
if scale_tma_aligned: if scale_tma_aligned:
# aligned to 4 * sizeof(float) # aligned to 4 * sizeof(float)
...@@ -302,7 +315,7 @@ def sglang_per_token_group_quant_fp8( ...@@ -302,7 +315,7 @@ def sglang_per_token_group_quant_fp8(
def sglang_per_token_quant_fp8( def sglang_per_token_quant_fp8(
x: torch.Tensor, x: torch.Tensor,
dtype: torch.dtype = _fp8_type, dtype: torch.dtype = fp8_dtype,
): ):
assert x.is_contiguous(), "`x` is not contiguous" assert x.is_contiguous(), "`x` is not contiguous"
...@@ -384,7 +397,7 @@ def static_quant_fp8( ...@@ -384,7 +397,7 @@ def static_quant_fp8(
assert x.is_contiguous(), "`x` is not contiguous" assert x.is_contiguous(), "`x` is not contiguous"
assert x_s.numel() == 1, "only supports per-tensor scale" assert x_s.numel() == 1, "only supports per-tensor scale"
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type) x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
M = x.numel() // x.shape[-1] M = x.numel() // x.shape[-1]
N = x.shape[-1] N = x.shape[-1]
if repeat_scale: if repeat_scale:
...@@ -704,6 +717,28 @@ def get_w8a8_block_fp8_configs( ...@@ -704,6 +717,28 @@ def get_w8a8_block_fp8_configs(
return None return None
def select_w8a8_block_fp8_matmul_kernel(M, N, META):
return _w8a8_block_fp8_matmul
if _is_hip:
def use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(
N, META["BLOCK_SIZE_N"]
)
num_workgroups <= get_device_core_count()
def select_w8a8_block_fp8_matmul_kernel(M, N, META):
if use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
return _w8a8_block_fp8_matmul_unrolledx4
else:
return _w8a8_block_fp8_matmul
def w8a8_block_fp8_matmul( def w8a8_block_fp8_matmul(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
...@@ -744,35 +779,6 @@ def w8a8_block_fp8_matmul( ...@@ -744,35 +779,6 @@ def w8a8_block_fp8_matmul(
C_shape = A.shape[:-1] + (N,) C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype) C = A.new_empty(C_shape, dtype=output_dtype)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
# deepgemm only support bf16 # deepgemm only support bf16
if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM: if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
if supports_custom_op(): if supports_custom_op():
...@@ -780,11 +786,30 @@ def w8a8_block_fp8_matmul( ...@@ -780,11 +786,30 @@ def w8a8_block_fp8_matmul(
else: else:
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C) deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
else: else:
kernel = ( configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
_w8a8_block_fp8_matmul_unrolledx4 if configs:
if (_is_hip == True and num_workgroups <= get_device_core_count()) # If an optimal configuration map has been found, look up the
else _w8a8_block_fp8_matmul # optimal config
) config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
kernel[grid]( kernel[grid](
A, A,
...@@ -879,7 +904,7 @@ def per_tensor_quant_mla_fp8( ...@@ -879,7 +904,7 @@ def per_tensor_quant_mla_fp8(
and x_s_out.device == x.device and x_s_out.device == x.device
) )
x_q = x.new_empty(x.size(), dtype=_fp8_type) x_q = x.new_empty(x.size(), dtype=fp8_dtype)
num_head, num_seq, head_size = x.shape num_head, num_seq, head_size = x.shape
BLOCK_SIZE = triton.next_power_of_2(head_size) BLOCK_SIZE = triton.next_power_of_2(head_size)
...@@ -961,11 +986,11 @@ def _per_token_group_quant_mla_deep_gemm_masked_fp8( ...@@ -961,11 +986,11 @@ def _per_token_group_quant_mla_deep_gemm_masked_fp8(
tl.store(y_s_ptr + gid * y_s_stride_g, y_s) tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
def per_tensor_quant_mla_deep_gemm_masked_fp8( def per_token_group_quant_mla_deep_gemm_masked_fp8(
x: torch.Tensor, x: torch.Tensor,
group_size: int = 128, group_size: int = 128,
eps: float = 1e-12, eps: float = 1e-12,
dtype: torch.dtype = torch.float8_e4m3fn, dtype: torch.dtype = fp8_dtype,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
This function quantizes input values to float8 values with per-token-group-quantization This function quantizes input values to float8 values with per-token-group-quantization
...@@ -973,12 +998,6 @@ def per_tensor_quant_mla_deep_gemm_masked_fp8( ...@@ -973,12 +998,6 @@ def per_tensor_quant_mla_deep_gemm_masked_fp8(
""" """
assert x.dim() == 3, "`x` is not a 3d-tensor" assert x.dim() == 3, "`x` is not a 3d-tensor"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
b, m, k = x.shape b, m, k = x.shape
aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel
num_tiles_k = k // group_size num_tiles_k = k // group_size
...@@ -1043,10 +1062,9 @@ def scaled_fp8_quant( ...@@ -1043,10 +1062,9 @@ def scaled_fp8_quant(
""" """
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D" assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
shape = input.shape shape = input.shape
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if num_token_padding: if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1]) shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype) output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
if scale is None: if scale is None:
# Dynamic scaling # Dynamic scaling
......
...@@ -14,6 +14,9 @@ except ImportError: ...@@ -14,6 +14,9 @@ except ImportError:
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
fp8_dtype,
fp8_max,
is_fp8_fnuz,
per_token_group_quant_fp8, per_token_group_quant_fp8,
scaled_fp8_quant, scaled_fp8_quant,
sglang_per_token_quant_fp8, sglang_per_token_quant_fp8,
...@@ -30,8 +33,11 @@ from sglang.srt.utils import ( ...@@ -30,8 +33,11 @@ from sglang.srt.utils import (
_is_hip = is_hip() _is_hip = is_hip()
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"): use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
if _is_hip and use_aiter_moe:
from aiter import gemm_a8w8_blockscale from aiter import gemm_a8w8_blockscale
if _is_cuda: if _is_cuda:
...@@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K ...@@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None TORCH_DEVICE_IDENTITY = None
_TORCH_VERSION = torch.__version__.split("+")[0]
try: def use_rowwise_torch_scaled_mm():
_TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3])) _TORCH_VERSION = torch.__version__.split("+")[0]
except ValueError: try:
_TORCH_VERSION_TUPLE = (0, 0, 0) _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
except ValueError:
# The condition to determine if it is on a platform that supports _TORCH_VERSION_TUPLE = (0, 0, 0)
# torch._scaled_mm rowwise feature. if _is_hip:
# The condition is determined once as the operations # The condition to determine if it is on a platform that supports
# are time consuming. # torch._scaled_mm rowwise feature.
USE_ROWWISE_TORCH_SCALED_MM = ( # The condition is determined once as the operations
_is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0) # are time consuming.
) return get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
return False
USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()
def cutlass_fp8_supported(): def cutlass_fp8_supported():
...@@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear( ...@@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear(
output = fp8_blockwise_scaled_mm( output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
) )
elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"): elif _is_hip and use_aiter_moe:
q_input, x_scale = per_token_group_quant_fp8( q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False input_2d, block_size[1], column_major_scales=False
) )
...@@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear( ...@@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear(
def input_to_float8( def input_to_float8(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn x: torch.Tensor, dtype: torch.dtype = fp8_dtype
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to float8 values with tensor-wise quantization.""" """This function quantizes input values to float8 values with tensor-wise quantization."""
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax() min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12) amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
fp8_max = finfo.max
if _is_hip: if _is_fp8_fnuz:
dtype = torch.float8_e4m3fnuz dtype = fp8_dtype
fp8_max = 224.0 fp_max = fp8_max
scale = fp8_max / amax else:
x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max) finfo = torch.finfo(dtype)
fp_max = finfo.max
scale = fp_max / amax
x_scl_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
......
...@@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import ( ...@@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_hip
_is_hip = is_hip()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase): ...@@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase):
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
) )
@classmethod
def is_fp8_fnuz(cls) -> bool:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
def apply(self, layer: torch.nn.Module) -> torch.Tensor: def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
...@@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase): ...@@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
# We prefer to use separate k_scale and v_scale if present # We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist() k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist() v_scale = layer.v_scale.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz(): if is_fp8_fnuz():
k_scale *= 2 k_scale *= 2
v_scale *= 2 v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0: elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
...@@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase): ...@@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
scale_to_duplicate = max(layer.k_scale, layer.v_scale) scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist() k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist() v_scale = scale_to_duplicate.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz(): if is_fp8_fnuz():
k_scale *= 2 k_scale *= 2
v_scale *= 2 v_scale *= 2
......
...@@ -14,11 +14,6 @@ if not _is_cuda: ...@@ -14,11 +14,6 @@ if not _is_cuda:
from vllm._custom_ops import scaled_fp8_quant from vllm._custom_ops import scaled_fp8_quant
def is_fp8_fnuz() -> bool:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
def is_layer_skipped( def is_layer_skipped(
prefix: str, prefix: str,
ignored_layers: List[str], ignored_layers: List[str],
......
...@@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import ( ...@@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.layers.quantization.fp8_kernel import (
fp8_dtype,
is_fp8_fnuz,
per_token_group_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import ( from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear, apply_fp8_linear,
cutlass_fp8_supported, cutlass_fp8_supported,
input_to_float8, input_to_float8,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
) )
from sglang.srt.utils import is_hip, set_weight_attrs from sglang.srt.utils import set_weight_attrs
_is_hip = is_hip() _is_fp8_fnuz = is_fp8_fnuz()
class W8A8Fp8Config(QuantizationConfig): class W8A8Fp8Config(QuantizationConfig):
...@@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase): ...@@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
if self.quantization_config.is_checkpoint_fp8_serialized: if self.quantization_config.is_checkpoint_fp8_serialized:
weight_scale = layer.weight_scale.detach() weight_scale = layer.weight_scale.detach()
# If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly. # If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
if _is_hip: if _is_fp8_fnuz:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale weight=weight, weight_scale=weight_scale
) )
...@@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase): ...@@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
layer.weight, layer.weight.shape[-1] layer.weight, layer.weight.shape[-1]
) )
weight_scale = weight_scale.t().contiguous() weight_scale = weight_scale.t().contiguous()
if _is_hip:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale
)
else: else:
# if cutlass not supported, we fall back to use torch._scaled_mm # if cutlass not supported, we fall back to use torch._scaled_mm
# which requires per tensor quantization on weight # which requires per tensor quantization on weight
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype) qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
# Update the layer with the new values. # Update the layer with the new values.
...@@ -227,7 +226,6 @@ class W8A8FP8MoEMethod: ...@@ -227,7 +226,6 @@ class W8A8FP8MoEMethod:
): ):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
......
...@@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size ...@@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.int8_utils import (
block_dequant as int8_block_dequant,
)
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip from sglang.srt.utils import BumpAllocator, add_prefix
_is_hip = is_hip()
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import awq_dequantize
else:
from vllm._custom_ops import awq_dequantize
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -59,8 +59,8 @@ from sglang.srt.layers.moe.topk import select_experts ...@@ -59,8 +59,8 @@ from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_deep_gemm_masked_fp8,
per_tensor_quant_mla_fp8, per_tensor_quant_mla_fp8,
per_token_group_quant_mla_deep_gemm_masked_fp8,
) )
from sglang.srt.layers.quantization.fp8_utils import ( from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant, block_quant_to_tensor_quant,
...@@ -738,9 +738,7 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -738,9 +738,7 @@ class DeepseekV2AttentionMLA(nn.Module):
if self.use_deep_gemm_bmm: if self.use_deep_gemm_bmm:
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = ( q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
per_tensor_quant_mla_deep_gemm_masked_fp8( per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
)
) )
q_nope_out = q_nope.new_empty( q_nope_out = q_nope.new_empty(
(self.num_local_heads, aligned_m, self.kv_lora_rank) (self.num_local_heads, aligned_m, self.kv_lora_rank)
...@@ -785,8 +783,8 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -785,8 +783,8 @@ class DeepseekV2AttentionMLA(nn.Module):
if self.use_deep_gemm_bmm: if self.use_deep_gemm_bmm:
attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = ( attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
per_tensor_quant_mla_deep_gemm_masked_fp8( per_token_group_quant_mla_deep_gemm_masked_fp8(
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn attn_output.transpose(0, 1)
) )
) )
attn_bmm_output = attn_output.new_empty( attn_bmm_output = attn_output.new_empty(
......
...@@ -7,9 +7,9 @@ import torch ...@@ -7,9 +7,9 @@ import torch
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_deep_gemm_masked_fp8,
per_tensor_quant_mla_fp8, per_tensor_quant_mla_fp8,
per_token_group_quant_fp8, per_token_group_quant_fp8,
per_token_group_quant_mla_deep_gemm_masked_fp8,
static_quant_fp8, static_quant_fp8,
w8a8_block_fp8_matmul, w8a8_block_fp8_matmul,
) )
...@@ -236,7 +236,7 @@ class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase): ...@@ -236,7 +236,7 @@ class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase):
with torch.inference_mode(): with torch.inference_mode():
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12) ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8( out, scale, _, _, _ = per_token_group_quant_mla_deep_gemm_masked_fp8(
x, group_size x, group_size
) )
out = out[:, :num_tokens, :] out = out[:, :num_tokens, :]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment