Unverified Commit 40652482 authored by HandH1998, committed by GitHub

Support Llama4 fp8 inference (#5194)


Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: zhyncs <me@zhyncs.com>
parent 86a876d8
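
The diff below threads a new per_channel_quant flag through the Triton fused-MoE path so that fp8 (and int8) weights carrying one scale per output channel, together with per-token activation scales, can be used by fused_moe/fused_experts. A rough, hedged usage sketch of the new flag, assuming the fused_moe signature shown in this diff; shapes and scale layouts mirror the new unit test at the end of this page:

import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, M, N, K, topk = 8, 4, 128, 256, 2
a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") / 10
score = torch.randn(M, E, dtype=torch.bfloat16, device="cuda")

# fp8 weights with one scale per output channel (per-channel quantization)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
w1 = (torch.rand(E, 2 * N, K, device="cuda") * fp8_max).to(torch.float8_e4m3fn)
w2 = (torch.rand(E, K, N, device="cuda") * fp8_max).to(torch.float8_e4m3fn)
w1_s = torch.rand(E, 2 * N, device="cuda") * 1e-2
w2_s = torch.rand(E, K, device="cuda") * 1e-2

out = fused_moe(
    a,
    w1,
    w2,
    score,
    topk,
    renormalize=False,
    use_fp8_w8a8=True,
    per_channel_quant=True,  # new flag introduced by this PR
    w1_scale=w1_s,
    w2_scale=w2_s,
    block_shape=None,
)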
@@ -342,6 +342,7 @@ def fused_moe_kernel(
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
 ):
     """
@@ -416,20 +417,7 @@ def fused_moe_kernel(
         )
         b_scale = tl.load(b_scale_ptrs)
-    if use_fp8_w8a8:
-        # block-wise
-        if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-            offs_bsn = offs_bn // group_n
-            b_scale_ptrs = (
-                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
-            )
-        # tensor-wise
-        else:
-            a_scale = tl.load(a_scale_ptr)
-            b_scale = tl.load(b_scale_ptr + off_experts)
-    if use_int8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
@@ -438,8 +426,7 @@ def fused_moe_kernel(
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
         # channel-wise
-        else:
-            # Load per-column scale for weights
+        elif per_channel_quant:
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
             )
@@ -447,6 +434,10 @@ def fused_moe_kernel(
             # Load per-token scale for activations
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
             a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+        # tensor-wise
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -753,6 +744,7 @@ def invoke_fused_moe_kernel(
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
     use_int4_w4a16: bool,
+    per_channel_quant: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
@@ -777,10 +769,15 @@ def invoke_fused_moe_kernel(
         if block_shape is None:
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
+            # activations apply per-token quantization when weights apply per-channel quantization by default
             if _is_cuda:
-                A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
+                A, A_scale = sgl_scaled_fp8_quant(
+                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
+                )
             else:
-                A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
+                A, A_scale = vllm_ops.scaled_fp8_quant(
+                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
+                )
         else:
             # activation block-wise fp8 quantization
             assert len(block_shape) == 2
@@ -796,6 +793,9 @@ def invoke_fused_moe_kernel(
         assert B_scale is not None
         if block_shape is None:
             # activation channel-wise int8 quantization
+            assert (
+                per_channel_quant
+            ), "int8 quantization only supports channel-wise quantization except for block-wise quantization"
             A, A_scale = per_token_quant_int8(A)
         else:
             # activation block-wise int8 quantization
@@ -904,6 +904,7 @@ def invoke_fused_moe_kernel(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        per_channel_quant=per_channel_quant,
         even_Ks=even_Ks,
         **config,
     )
@@ -1086,6 +1087,7 @@ def inplace_fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1107,6 +1109,7 @@ def inplace_fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1129,6 +1132,7 @@ def inplace_fused_experts_fake(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1160,6 +1164,7 @@ def outplace_fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1182,6 +1187,7 @@ def outplace_fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1205,6 +1211,7 @@ def outplace_fused_experts_fake(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1238,6 +1245,7 @@ def fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1261,6 +1269,7 @@ def fused_experts(
             use_int8_w8a8,
             use_int8_w8a16,
             use_int4_w4a16,
+            per_channel_quant,
             w1_scale,
             w2_scale,
             w1_zp,
@@ -1283,6 +1292,7 @@ def fused_experts(
             use_int8_w8a8,
             use_int8_w8a16,
             use_int4_w4a16,
+            per_channel_quant,
             w1_scale,
             w2_scale,
             w1_zp,
@@ -1307,6 +1317,7 @@ def fused_experts_impl(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1443,6 +1454,7 @@ def fused_experts_impl(
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             use_int4_w4a16=use_int4_w4a16,
+            per_channel_quant=per_channel_quant,
             block_shape=block_shape,
         )
         if activation == "silu":
@@ -1486,6 +1498,7 @@ def fused_experts_impl(
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             use_int4_w4a16=use_int4_w4a16,
+            per_channel_quant=per_channel_quant,
             block_shape=block_shape,
         )
@@ -1532,6 +1545,7 @@ def fused_moe(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1608,6 +1622,7 @@ def fused_moe(
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         use_int4_w4a16=use_int4_w4a16,
+        per_channel_quant=per_channel_quant,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         w1_zp=w1_zp,
......
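The kernel change above merges the fp8 and int8 scale-selection branches and adds a per-channel path that loads one scale per token for activations and one scale per output channel for weights, applying them after the low-precision accumulate. That is valid because row and column scales factor out of the matmul; a small numeric check of the identity in plain PyTorch (illustrative only, not the Triton kernel code):

import torch

M, K, N = 4, 64, 8
a_q = torch.randint(-127, 128, (M, K)).float()   # stand-in for quantized activations
b_q = torch.randint(-127, 128, (N, K)).float()   # stand-in for quantized weights
a_s = torch.rand(M, 1)                           # one scale per token (row)
b_s = torch.rand(1, N)                           # one scale per output channel (column)

dequant_then_matmul = (a_q * a_s) @ (b_q.t() * b_s)
matmul_then_scale = (a_q @ b_q.t()) * a_s * b_s  # what the kernel effectively does

assert torch.allclose(dequant_then_matmul, matmul_then_scale, rtol=1e-5, atol=1e-3)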
@@ -77,6 +77,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_ignore_list: List[str],
         kv_cache_scheme: Optional[Dict[str, Any]] = None,
         config: Optional[Dict[str, Any]] = None,
+        packed_modules_mapping: Dict[str, List[str]] = {},
     ):
         super().__init__()
         self.ignore = ignore
@@ -87,6 +88,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.sparsity_scheme_map = sparsity_scheme_map
         self.sparsity_ignore_list = sparsity_ignore_list
         self.config = config
+        self.packed_modules_mapping = packed_modules_mapping

     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -136,6 +138,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
             config=config
         )
+        packed_modules_mapping = config.get("packed_modules_mapping", {})

         return cls(
             target_scheme_map=target_scheme_map,
@@ -144,6 +147,7 @@ class CompressedTensorsConfig(QuantizationConfig):
             sparsity_scheme_map=sparsity_scheme_map,
             sparsity_ignore_list=sparsity_ignore_list,
             config=config,
+            packed_modules_mapping=packed_modules_mapping,
         )

     @classmethod
......
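The new packed_modules_mapping argument carries the model's fused-module layout into the quantization config: checkpoint-side quantization metadata refers to unfused names (q_proj, gate_proj, ...) while sglang's layers use fused ones (qkv_proj, gate_up_proj). The Llama4 value appears verbatim in the model change later in this diff; the expand() helper below is illustrative only, not sglang API:

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def expand(fused_name: str) -> list:
    """Map a fused sglang module name back to the checkpoint-side names."""
    return packed_modules_mapping.get(fused_name, [fused_name])

print(expand("gate_up_proj"))  # ['gate_proj', 'up_proj']
print(expand("o_proj"))        # ['o_proj']  (not a packed module)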
@@ -103,16 +103,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             "input_activations"
         )

-        if not (
-            self.weight_quant.strategy == QuantizationStrategy.TENSOR
-            and self.input_quant.strategy == QuantizationStrategy.TENSOR
-        ):
-            raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
-                f"{self.weight_quant}, {self.input_quant}"
-            )
-
         self.static_input_scales = not self.input_quant.dynamic

     def create_weights(
@@ -154,27 +144,50 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)

         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
+        # per-tensor quantization
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    1,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+                requires_grad=False,
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
+        else:
+            raise ValueError(
+                f"Unsupported weight quantization strategy: {self.weight_quant.strategy}"
+            )
+
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)

         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
+        extra_weight_attrs.update({"quant_method": weight_quant_method})
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

         # INPUT_SCALES
         if self.static_input_scales:
+            assert (
+                self.input_quant.strategy == QuantizationStrategy.TENSOR
+            ), "Only per-tensor quantization is supported for static input scales"
             w13_input_scale = torch.nn.Parameter(
                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
             )
@@ -241,31 +254,37 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(
                 w2_input_scale, requires_grad=False
             )

-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start : start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id],
-                )
-                if _is_cuda:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                else:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                start += shard_size
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id],
+                    )
+                    if _is_cuda:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    else:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = vllm_ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id]
+                        )
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(
+                max_w13_scales, requires_grad=False
+            )

     def apply(
         self,
@@ -311,6 +330,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy
+            == QuantizationStrategy.CHANNEL,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
......
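In the compressed-tensors MoE method above, the per-tensor branch still folds the separate w1/w3 scales into one per-expert scale by dequantizing each shard with its own scale and requantizing with the maximum of the two. A toy illustration of that dequant/requant step with made-up shapes and scales (not the sglang API, which uses per_tensor_dequantize and scaled_fp8_quant):

import torch

fp8 = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8).max            # 448 for e4m3fn

shard = torch.randn(16, 32)               # made-up w1 shard in high precision
s1, s3 = 0.01, 0.03                       # original per-shard scales
q1 = (shard / s1).clamp(-fp8_max, fp8_max).to(fp8)      # shard quantized with its own scale

s_max = max(s1, s3)                       # single scale kept for the fused kernel
dq = q1.to(torch.float32) * s1            # dequantize with the old scale
q1_new = (dq / s_max).clamp(-fp8_max, fp8_max).to(fp8)  # requantize with the shared scale

recon = q1_new.to(torch.float32) * s_max
print((recon - dq).abs().max())           # small requantization error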
@@ -217,6 +217,15 @@ def block_quant_to_tensor_quant(
     return x_q_tensor, scale


+def channel_quant_to_tensor_quant(
+    x_q_channel: torch.Tensor,
+    x_s: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    x_dq_channel = x_q_channel.to(torch.float32) * x_s
+    x_q_tensor, scale = input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
+    return x_q_tensor, scale
+
+
 def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
......
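A hedged usage sketch for the new channel_quant_to_tensor_quant helper above: it dequantizes a weight that carries one scale per output channel and requantizes it with a single per-tensor scale, which is what per-tensor kernels such as bmm_fp8 expect (see the DeepseekV2 change later in this diff). The shapes below are illustrative:

import torch

from sglang.srt.layers.quantization.fp8_utils import channel_quant_to_tensor_quant

N, K = 512, 256
w_q = torch.randn(N, K).clamp(-448, 448).to(torch.float8_e4m3fn)  # channel-quantized weight
w_s = torch.rand(N, 1) * 1e-2                                     # one scale per output channel

w_q_tensor, w_scale = channel_quant_to_tensor_quant(w_q, w_s)
# w_q_tensor keeps the fp8 dtype; w_scale is a single per-tensor scale
print(w_q_tensor.dtype, w_scale.shape)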
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import torch
 from torch.nn.parameter import Parameter

@@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_hip, set_weight_attrs

 _is_hip = is_hip()

@@ -62,7 +62,9 @@ class W8A8Fp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
         quant_method = cls.get_from_keys(config, ["quant_method"])
-        is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
+        is_checkpoint_fp8_serialized = (
+            "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
+        )
         return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)

     def get_quant_method(
@@ -71,9 +73,12 @@ class W8A8Fp8Config(QuantizationConfig):
         prefix: str,
     ) -> Optional["QuantizeMethodBase"]:
         from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if isinstance(layer, LinearBase):
             return W8A8Fp8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W8A8FP8MoEMethod(self)
         return None

     def get_scaled_act_names(self) -> List[str]:
@@ -131,7 +136,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-        **extra_weight_attrs
+        **extra_weight_attrs,
     ):
         weight_dtype = (
             torch.float8_e4m3fn
@@ -177,3 +182,148 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
         )
+
+
+class W8A8FP8MoEMethod:
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+        )
+
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        w13_input_scale = None
+        layer.register_parameter("w13_input_scale", w13_input_scale)
+
+        w2_input_scale = None
+        layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        layer.w13_weight_scale = Parameter(
+            layer.w13_weight_scale.data, requires_grad=False
+        )
+        layer.w2_weight_scale = Parameter(
+            layer.w2_weight_scale.data, requires_grad=False
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=True,
+            w1_scale=(layer.w13_weight_scale),
+            w2_scale=(layer.w2_weight_scale),
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
@@ -260,6 +260,7 @@ class W8A8Int8MoEMethod:
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             use_int8_w8a8=True,
+            per_channel_quant=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
......
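W8A8FP8MoEMethod above re-parents itself onto FusedMoEMethodBase inside __new__ rather than at class-definition time, presumably so this module does not have to import the fused-MoE package when it is loaded. A minimal standalone sketch of that pattern with made-up class names (LateBase stands in for FusedMoEMethodBase; this is an illustration of the idiom, not the sglang classes):

class LateBase:                      # stands in for FusedMoEMethodBase
    def shared_helper(self):
        return "helper from the late base"

class Method:                        # stands in for W8A8FP8MoEMethod
    def __new__(cls, *args, **kwargs):
        # Build a type that inherits from LateBase but reuses this class's
        # methods, then instantiate and initialize that type instead of cls.
        new_cls = type(
            cls.__name__,
            (LateBase,),
            {k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")},
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, tag):
        self.tag = tag

m = Method("w8a8_fp8")
print(isinstance(m, LateBase), m.shared_helper(), m.tag)  # True ... w8a8_fp8
# Caveat of the pattern: the instance is not an instance of Method itself.
print(isinstance(m, Method))                              # False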
@@ -108,11 +108,15 @@ logger = logging.getLogger(__name__)


 def _get_quantization_config(
-    model_config: ModelConfig, load_config: LoadConfig
+    model_config: ModelConfig,
+    load_config: LoadConfig,
+    packed_modules_mapping: Dict[str, List[str]],
 ) -> Optional[QuantizationConfig]:
     """Get the quantization config."""
     if model_config.quantization is not None:
-        quant_config = get_quant_config(model_config, load_config)
+        quant_config = get_quant_config(
+            model_config, load_config, packed_modules_mapping
+        )
         major, minor = get_device_capability()
         if major is not None and minor is not None:
@@ -142,7 +146,10 @@ def _initialize_model(
 ) -> nn.Module:
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
-    quant_config = _get_quantization_config(model_config, load_config)
+    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+    quant_config = _get_quantization_config(
+        model_config, load_config, packed_modules_mapping
+    )
     return model_class(
         config=model_config.hf_config,
         quant_config=quant_config,
......
@@ -129,7 +129,9 @@ def convert_bin_to_safetensor_file(

 # TODO(woosuk): Move this to other place.
 def get_quant_config(
-    model_config: ModelConfig, load_config: LoadConfig
+    model_config: ModelConfig,
+    load_config: LoadConfig,
+    packed_modules_mapping: Dict[str, List[str]],
 ) -> QuantizationConfig:
     quant_cls = get_quantization_config(model_config.quantization)
@@ -147,6 +149,7 @@ def get_quant_config(
     # compressed-tensors uses a compressions_config
     hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
     if hf_quant_config is not None:
+        hf_quant_config["packed_modules_mapping"] = packed_modules_mapping
         return quant_cls.from_config(hf_quant_config)

     # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
     if model_config.quantization == "bitsandbytes":
......
@@ -55,6 +55,7 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
+    channel_quant_to_tensor_quant,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
@@ -1411,27 +1412,34 @@ class DeepseekV2ForCausalLM(nn.Module):
             w = self_attn.kv_b_proj.weight
             # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
             # This may affect the accuracy of fp8 model.
-            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+            if w.dtype in (
                 torch.float8_e4m3fn,
                 torch.float8_e4m3fnuz,
             ):
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    if _is_hip:
-                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                            weight=w,
-                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                            input_scale=None,
-                        )
-                    else:
-                        weight = w
-                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                    w, scale = block_quant_to_tensor_quant(
-                        weight, weight_scale, weight_block_size
-                    )
+                if hasattr(self.quant_config, "weight_block_size"):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
+                else:
+                    weight = w
+                    weight_scale = self_attn.kv_b_proj.weight_scale
+                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                     self_attn.w_scale = scale

             if w.dtype == torch.int8:
                 if hasattr(self.quant_config, "weight_block_size"):
                     # block-wise int8 need it
......
@@ -414,7 +414,7 @@ class Llama4Model(nn.Module):
             lambda idx, prefix: Llama4DecoderLayer(
                 config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
             ),
-            prefix="model.layers",
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
@@ -7,6 +7,7 @@ from torch import nn
 from transformers import Llama4Config

 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -16,6 +17,7 @@ from sglang.srt.utils import add_prefix
 class Llama4ForConditionalGeneration(nn.Module):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }

     def __init__(
@@ -96,6 +98,15 @@ class Llama4ForConditionalGeneration(nn.Module):

         num_experts = self.config.text_config.num_local_experts

+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=num_experts,
+        )
+
         for name, loaded_weight in weights:
             if name.startswith("vision_model") or name.startswith(
@@ -115,31 +126,54 @@ class Llama4ForConditionalGeneration(nn.Module):
                     break
             else:
                 if ".experts" in name:
-                    if ".gate_up_proj" in name:
-                        name_list = [
-                            name.replace(".experts.gate_up_proj", ".experts.w13_weight")
-                        ] * 2
-                        loaded_weight_list = loaded_weight.chunk(2, dim=-1)
-                        shard_id_list = ["w1", "w3"]
-                    else:
-                        name_list = [
-                            name.replace(".experts.down_proj", ".experts.w2_weight")
-                        ]
-                        shard_id_list = ["w2"]
-                        loaded_weight_list = [loaded_weight]
-                    for name, loaded_weight, shard_id in zip(
-                        name_list, loaded_weight_list, shard_id_list
-                    ):
-                        param = params_dict[name]
-                        weight_loader = param.weight_loader
-                        for expert_id in range(num_experts):
-                            weight_loader(
-                                param,
-                                loaded_weight[expert_id].T,
-                                name,
-                                shard_id=shard_id,
-                                expert_id=expert_id,
-                            )
+                    # NOTE: llama4 fp8 has different weight format for experts
+                    if (
+                        "experts.gate_up_proj" not in name
+                        and "experts.down_proj" not in name
+                    ):
+                        for mapping in expert_params_mapping:
+                            param_name, weight_name, expert_id, shard_id = mapping
+                            if weight_name not in name:
+                                continue
+                            name = name.replace(weight_name, param_name)
+                            param = params_dict[name]
+                            weight_loader = param.weight_loader
+                            weight_loader(
+                                param,
+                                loaded_weight,
+                                name,
+                                shard_id=shard_id,
+                                expert_id=expert_id,
+                            )
+                            break
+                    else:
+                        if ".gate_up_proj" in name:
+                            name_list = [
+                                name.replace(
+                                    ".experts.gate_up_proj", ".experts.w13_weight"
+                                )
+                            ] * 2
+                            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+                            shard_id_list = ["w1", "w3"]
+                        else:
+                            name_list = [
+                                name.replace(".experts.down_proj", ".experts.w2_weight")
+                            ]
+                            shard_id_list = ["w2"]
+                            loaded_weight_list = [loaded_weight]
+                        for name, loaded_weight, shard_id in zip(
+                            name_list, loaded_weight_list, shard_id_list
+                        ):
+                            param = params_dict[name]
+                            weight_loader = param.weight_loader
+                            for expert_id in range(num_experts):
+                                weight_loader(
+                                    param,
+                                    loaded_weight[expert_id].T,
+                                    name,
+                                    shard_id=shard_id,
+                                    expert_id=expert_id,
+                                )
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
......
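The new expert-loading branch above walks FusedMoE.make_expert_params_mapping and rewrites checkpoint keys into the fused w13_/w2_ parameters. The sketch below shows the kind of (param_name, weight_name, expert_id, shard_id) tuples such a mapping could contain and how one made-up checkpoint key would be rewritten; the tuple layout and the key format are assumptions, not the documented FusedMoE contract:

num_experts = 2

# One plausible shape for the mapping (assumption):
expert_params_mapping = [
    (
        "experts.w13_" if proj in ("gate_proj", "up_proj") else "experts.w2_",
        f"experts.{e}.{proj}.",
        e,
        shard,
    )
    for e in range(num_experts)
    for shard, proj in (("w1", "gate_proj"), ("w3", "up_proj"), ("w2", "down_proj"))
]

ckpt_key = "language_model.model.layers.0.feed_forward.experts.1.down_proj.weight"  # made up
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
    if weight_name not in ckpt_key:
        continue
    fused_key = ckpt_key.replace(weight_name, param_name)
    print(fused_key, expert_id, shard_id)
    # -> ...feed_forward.experts.w2_weight  1  w2
    break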
@@ -76,6 +76,7 @@ suites = {
         TestFile("test_create_kvindices.py", 2),
         TestFile("test_hicache.py", 60),
         TestFile("test_hicache_mla.py", 90),
+        TestFile("test_triton_moe_channel_fp8_kernel.py", 25),
     ],
     "per-commit-2-gpu": [
         TestFile("models/lora/test_lora_tp.py", 300),
......
@@ -124,6 +124,7 @@ class TestW8A8Int8FusedMoE(CustomTestCase):
             use_fp8_w8a8=False,  # Not using fp8
             use_int8_w8a16=False,  # Not using int8-w8a16
             use_int8_w8a8=True,  # Using int8-w8a8
+            per_channel_quant=True,
             w1_scale=w1_s,
             w2_scale=w2_s,
             block_shape=None,  # Not using block quantization
......
import itertools
import unittest

import torch

from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.test.test_utils import CustomTestCase


def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    """Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
    assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"

    # Reshape input
    M = A.numel() // A.shape[-1]
    B = B.t()  # Transpose weight matrix
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (K,)
    A = A.reshape(M, N)

    # As is per-token [M, 1], Bs is per-column [1, K]
    C = torch.matmul(A, B)  # [M, K]
    C = As * C * Bs.view(1, -1)  # Broadcast per-column scale

    return C.reshape(origin_C_shape).to(output_dtype)


def fp8_mask(a, mask):
    dtype = a.dtype
    return a.view(torch.int8)[mask].view(dtype)


def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
    """This function performs fused moe with per-column int8 quantization using native torch."""
    B, D = a.shape
    # Perform per-token quantization
    a_q, a_s = sgl_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Also repeat the scale
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]

    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    # Calculate routing
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Process each expert
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            inter_out = native_w8a8_per_token_matmul(
                fp8_mask(a_q, mask),
                w1[i],
                fp8_mask(a_s, mask),
                w1_s[i],
                output_dtype=a.dtype,
            )
            # Activation function
            act_out = SiluAndMul().forward_native(inter_out)
            # Quantize activation output with per-token
            act_out_q, act_out_s = sgl_scaled_fp8_quant(
                act_out, use_per_token_if_dynamic=True
            )

            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
            )
    # Apply routing weights and sum
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


class TestW8A8FP8FusedMoE(CustomTestCase):
    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 33]
    N = [128, 1024]
    K = [256, 4096]
    E = [8]
    TOP_KS = [2, 6]
    BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_fp8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed):
        torch.manual_seed(seed)
        # Initialize int8 quantization parameters
        factor_for_scale = 1e-2
        finfo = torch.finfo(torch.float8_e4m3fn)
        fp8_max = finfo.max
        fp8_min = finfo.min

        # Input tensor
        # M * K
        a = torch.randn((M, K), dtype=dtype) / 10

        # Generate int8 weights
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
        w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
        w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        # Generate scale for each column (per-column quantization)
        w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
        w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale

        score = torch.randn((M, E), dtype=dtype)

        with torch.inference_mode():
            ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
            out = fused_moe(
                a,
                w1,
                w2,
                score,
                topk,
                renormalize=False,
                use_fp8_w8a8=True,  # using fp8
                use_int8_w8a16=False,
                use_int8_w8a8=False,
                per_channel_quant=True,
                w1_scale=w1_s,
                w2_scale=w2_s,
                block_shape=None,  # Not using block quantization
            )

        # Check results
        self.assertTrue(
            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
            / torch.mean(torch.abs(ref_out.to(torch.float32)))
            < 0.05
        )

    def test_w8a8_fp8_fused_moe(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
                block_size=params[5],
                dtype=params[6],
                seed=params[7],
            ):
                self._w8a8_fp8_fused_moe(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)
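
The acceptance criterion used by the test above, written out as a standalone helper: the mean absolute difference between the fused kernel output and the torch reference, normalized by the mean magnitude of the reference, must stay below 5%.

import torch

def relative_l1_error(out: torch.Tensor, ref: torch.Tensor) -> float:
    # mean |out - ref| / mean |ref|, computed in float32 as in the test
    out32, ref32 = out.to(torch.float32), ref.to(torch.float32)
    return (torch.mean(torch.abs(out32 - ref32)) / torch.mean(torch.abs(ref32))).item()

# e.g. self.assertLess(relative_l1_error(out, ref_out), 0.05)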