Unverified Commit 53aed988 authored by HandH1998, committed by GitHub

Refactor MoE (#2575)


Co-authored-by: zhyncs <me@zhyncs.com>
parent 8a56b431
@@ -94,7 +94,10 @@ class ModelConfig:
)
# FIXME: temporary special judge for MLA architecture
if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
if (
"DeepseekV2ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLM" in self.hf_config.architectures
):
self.head_dim = 256
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_config.kv_lora_rank
......
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
from sglang.srt.utils import set_weight_attrs
logger = logging.getLogger(__name__)
@@ -628,8 +629,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes)
tp_size = get_tensor_model_parallel_world_size()
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = self.quant_method.quant_config.weight_block_size
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
) // tp_size
shard_size = (
(self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
param.load_merged_column_weight(
loaded_weight=loaded_weight,
@@ -795,6 +807,12 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = self.quant_method.quant_config.weight_block_size
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (shard_offset + block_n - 1) // block_n
shard_size = (shard_size + block_n - 1) // block_n
param.load_qkv_weight(
loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
......
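The `BlockQuantScaleParameter` branch above maps output-feature offsets into scale-grid units via ceiling division by `block_n` before tensor-parallel slicing. A minimal sketch of that arithmetic, assuming a hypothetical `weight_block_size = [128, 128]` and two merged projections:

# Hypothetical sizes: merged projections of 512 and 1024 output features, tp_size = 2.
output_sizes = [512, 1024]
tp_size = 2
block_n = 128  # assumed weight_block_size[0]

def scale_shard(loaded_shard_id):
    # Ceiling division converts feature counts into scale-grid rows, then TP slices them.
    shard_offset = (
        (sum(output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
    ) // tp_size
    shard_size = (output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
    return shard_offset, shard_size

print(scale_shard(0))  # (0, 2): first 2 scale rows of this rank's shard
print(scale_shard(1))  # (2, 4): next 4 scale rows of this rank's shard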
@@ -6,7 +6,7 @@ import functools
import json
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import triton
@@ -14,6 +14,7 @@ import triton.language as tl
from vllm import _custom_ops as ops
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.utils import direct_register_custom_op, get_device_name
logger = logging.getLogger(__name__)
@@ -48,8 +49,14 @@ def fused_moe_kernel(
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
@@ -133,8 +140,15 @@ def fused_moe_kernel(
b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8:
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
)
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
@@ -165,7 +179,17 @@ def fused_moe_kernel(
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(
a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
@@ -178,7 +202,10 @@ def fused_moe_kernel(
if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
@@ -262,6 +289,7 @@ def invoke_fused_moe_kernel(
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: Optional[List[int]] = None,
) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
@@ -269,8 +297,16 @@ def invoke_fused_moe_kernel(
padded_size = 0
if use_fp8_w8a8:
padded_size = padding_size
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
if block_shape is None:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16:
assert B_scale is not None
else:
@@ -309,8 +345,13 @@ def invoke_fused_moe_kernel(
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
@@ -415,6 +456,7 @@ def try_get_optimal_moe_config(
dtype: Optional[str],
M: int,
is_marlin: bool = False,
block_shape: Optional[List[int]] = None,
):
from sglang.srt.layers.moe.fused_moe_triton import get_config
@@ -433,6 +475,13 @@ def try_get_optimal_moe_config(
else:
# Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
# TODO(HandH1998): Optimize the configs of block-wise quant.
# NOTE(HandH1998): For block-wise quant,
# BLOCK_K must be divisible by block_shape[1]
# BLOCK_N and BLOCK_M have no requirements
if block_shape is not None:
config["BLOCK_SIZE_N"] = block_shape[0]
config["BLOCK_SIZE_K"] = block_shape[1]
return config
@@ -464,6 +513,7 @@ def inplace_fused_experts(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> None:
fused_experts_impl(
hidden_states,
@@ -478,6 +528,7 @@ def inplace_fused_experts(
w2_scale,
a1_scale,
a2_scale,
block_shape,
)
@@ -493,6 +544,7 @@ def inplace_fused_experts_fake(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> None:
pass
@@ -517,6 +569,7 @@ def outplace_fused_experts(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
@@ -531,6 +584,7 @@ def outplace_fused_experts(
w2_scale,
a1_scale,
a2_scale,
block_shape,
)
@@ -546,6 +600,7 @@ def outplace_fused_experts_fake(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -571,6 +626,7 @@ def fused_experts(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
):
if inplace:
torch.ops.sglang.inplace_fused_experts(
@@ -585,6 +641,7 @@ def fused_experts(
w2_scale,
a1_scale,
a2_scale,
block_shape,
)
return hidden_states
else:
@@ -600,6 +657,7 @@ def fused_experts(
w2_scale,
a1_scale,
a2_scale,
block_shape,
)
@@ -616,6 +674,7 @@ def fused_experts_impl(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
):
padded_size = padding_size
if not use_fp8_w8a8:
@@ -647,6 +706,7 @@ def fused_experts_impl(
(w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
topk_ids.shape[1],
config_dtype,
block_shape=block_shape,
)
config = get_config_func(M)
@@ -719,6 +779,7 @@ def fused_experts_impl(
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape,
)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -740,6 +801,7 @@ def fused_experts_impl(
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape,
)
torch.sum(
@@ -768,6 +830,7 @@ def fused_moe(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -795,6 +858,12 @@ def fused_moe(
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape (Optional[List[int]]): Optional block size for block-wise
quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
@@ -826,4 +895,5 @@ def fused_moe(
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
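A minimal sketch of calling the extended `fused_moe` with block-wise FP8 weights (shape conventions follow the asserts in `invoke_fused_moe_kernel`; assumes a CUDA device with fp8 support, and all sizes here are illustrative):

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, N, K, M, topk = 8, 1024, 4096, 16, 2
block_n, block_k = 128, 128

a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(E, 2 * N, K, device="cuda").to(torch.float8_e4m3fn)
w2 = torch.randn(E, K, N, device="cuda").to(torch.float8_e4m3fn)
# One float32 scale per (block_n x block_k) weight tile.
w1_s = torch.ones(E, (2 * N + block_n - 1) // block_n, (K + block_k - 1) // block_k, device="cuda")
w2_s = torch.ones(E, (K + block_n - 1) // block_n, (N + block_k - 1) // block_k, device="cuda")
gating = torch.randn(M, E, dtype=torch.bfloat16, device="cuda")

out = fused_moe(
    a, w1, w2, gating, topk,
    renormalize=True,
    use_fp8_w8a8=True,
    w1_scale=w1_s,
    w2_scale=w2_s,
    block_shape=[block_n, block_k],  # activations get per-token-group scales of size block_k
)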
@@ -34,6 +34,7 @@ class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
GROUP = "group"
BLOCK = "block"
class FusedMoEMethodBase(QuantizeMethodBase):
@@ -214,6 +215,7 @@ class FusedMoE(torch.nn.Module):
)
self.top_k = top_k
self.num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
@@ -470,7 +472,10 @@ class FusedMoE(torch.nn.Module):
expert_data=expert_data,
tp_rank=tp_rank,
)
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
elif quant_method in [
FusedMoeWeightScaleSupported.GROUP.value,
FusedMoeWeightScaleSupported.BLOCK.value,
]:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
......
@@ -9,6 +9,7 @@ import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -32,7 +33,11 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.fp8_utils import (
BlockQuantScaleParameter,
apply_w8a8_block_fp8_linear,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import (
get_bool_env_var,
is_hip,
@@ -53,6 +58,7 @@ class Fp8Config(QuantizationConfig):
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
weight_block_size: List[int] = None,
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
@@ -64,6 +70,20 @@ class Fp8Config(QuantizationConfig):
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme
self.ignored_layers = ignored_layers or []
if weight_block_size is not None:
if not is_checkpoint_fp8_serialized:
raise ValueError(
f"The block-wise quantization only supports fp8-serialized checkpoint for now."
)
if len(weight_block_size) != 2:
raise ValueError(
f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
)
if activation_scheme != "dynamic":
raise ValueError(
f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
)
self.weight_block_size = weight_block_size
@classmethod
def get_name(cls) -> str:
@@ -87,10 +107,12 @@ class Fp8Config(QuantizationConfig):
is_checkpoint_fp8_serialized = "fp8" in quant_method
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
return cls(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=weight_block_size,
)
def get_quant_method(
@@ -143,6 +165,11 @@ class Fp8LinearMethod(LinearMethodBase):
if is_hip():
self.use_marlin = False
self.block_quant = self.quant_config.weight_block_size is not None
if self.block_quant:
# Marlin doesn't support block-wise fp8
self.use_marlin = False
def create_weights(
self,
layer: torch.nn.Module,
@@ -153,10 +180,35 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
)
# Required by row parallel
if tp_size > 1 and input_size // input_size_per_partition == tp_size:
if input_size_per_partition % block_k != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}."
)
# Required by column parallel or enabling merged weights
if (
tp_size > 1 and output_size // output_size_per_partition == tp_size
) or len(output_partition_sizes) > 1:
for output_partition_size in output_partition_sizes:
if output_partition_size % block_n != 0:
raise ValueError(
f"Weight output_partition_size = "
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}."
)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
@@ -184,13 +236,27 @@ class Fp8LinearMethod(LinearMethodBase):
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
scale = BlockQuantScaleParameter(
data=torch.empty(
(output_size_per_partition + block_n - 1) // block_n,
(input_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale_inv", scale)
else:
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)
# INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
@@ -205,6 +271,9 @@ class Fp8LinearMethod(LinearMethodBase):
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
return
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -295,6 +364,16 @@ class Fp8LinearMethod(LinearMethodBase):
bias=bias,
)
if self.block_quant:
return apply_w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
)
return apply_fp8_linear(
input=x,
weight=layer.weight,
@@ -339,6 +418,7 @@ class Fp8MoEMethod:
def __init__(self, quant_config):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
def create_weights(
self,
@@ -353,6 +433,28 @@ class Fp8MoEMethod:
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
)
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1:
# Required by row parallel
if intermediate_size % block_k != 0:
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_k = {block_k}."
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
@@ -374,21 +476,45 @@ class Fp8MoEMethod:
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if self.block_quant:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
else:
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
if self.block_quant
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
@@ -422,7 +548,9 @@ class Fp8MoEMethod:
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
return
# If checkpoint is fp16 or bfloat16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
@@ -519,7 +647,6 @@ class Fp8MoEMethod:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_weight_scale is not None
@@ -594,10 +721,17 @@ class Fp8MoEMethod:
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
......
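Block-wise FP8 is enabled purely from the checkpoint's `quantization_config`; a DeepSeek-V3-style entry looks roughly like the following (field values are illustrative assumptions, not taken from this diff):

quantization_config = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",   # block quant requires dynamic activation scales
    "weight_block_size": [128, 128],  # [block_n, block_k]
}
# Fp8Config.from_config(quantization_config) picks up weight_block_size,
# which switches Fp8LinearMethod / Fp8MoEMethod into the block-quant paths above.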
from typing import List, Tuple
import torch
import triton
import triton.language as tl
@triton.jit
def _per_token_group_quant_fp8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
# Stride of input
y_stride,
# Columns of input
N,
# Avoid division by zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group quantization on a
tensor.
This function converts the tensor values into float8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * y_stride
y_q_ptr += g_id * y_stride
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.float8_e4m3fn,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum value to avoid division by zero.
dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s
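A minimal usage sketch of `per_token_group_quant_fp8` (assumes a CUDA device with fp8 support): for an input of shape (M, N) and `group_size=128`, the quantized tensor keeps the input shape and the scale tensor stores one float32 value per 128-element group.

import torch
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

x = torch.randn(4, 512, dtype=torch.bfloat16, device="cuda")
x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
print(x_q.shape, x_q.dtype)  # torch.Size([4, 512]) torch.float8_e4m3fn
print(x_s.shape, x_s.dtype)  # torch.Size([4, 4]) torch.float32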
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
# TODO(HandH1998):
# BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized.
# BLOCK_SIZE_K must be divisible by block_k
# BLOCK_SIZE_N and BLOCK_SIZE_M have no requirements
BLOCK_SIZE_M = 128
if M < BLOCK_SIZE_M:
BLOCK_SIZE_M = triton.next_power_of_2(M)
BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
BLOCK_SIZE_K = block_k
assert block_k % BLOCK_SIZE_K == 0
BLOCK_SIZE_N = block_n
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
_w8a8_block_fp8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=8,
)
return C
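A shape-level usage sketch of `w8a8_block_fp8_matmul` (assumes a CUDA device; sizes are illustrative): `A` is (M, K) fp8 with per-token-group scales `As` of shape (M, ceil(K/block_k)), and `B` is (N, K) fp8 with per-block scales `Bs` of shape (ceil(N/block_n), ceil(K/block_k)).

import torch
import triton
from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul

M, N, K = 64, 1024, 4096
block_n, block_k = 128, 128
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
As = torch.rand(M, triton.cdiv(K, block_k), dtype=torch.float32, device="cuda")
Bs = torch.rand(triton.cdiv(N, block_n), triton.cdiv(K, block_k), dtype=torch.float32, device="cuda")
C = w8a8_block_fp8_matmul(A, B, As, Bs, [block_n, block_k], output_dtype=torch.bfloat16)
print(C.shape)  # torch.Size([64, 1024])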
from typing import Optional, Tuple
from typing import List, Optional, Tuple
import torch
from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)
def normalize_e4m3fn_to_e4m3fnuz(
@@ -25,3 +31,86 @@ def normalize_e4m3fn_to_e4m3fnuz(
if input_scale is not None:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
output = w8a8_block_fp8_matmul(
q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
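For a single linear layer, `apply_w8a8_block_fp8_linear` expects the fp8 weight in (out_features, in_features) layout with one float32 scale per [block_n, block_k] tile; a minimal sketch with assumed sizes (requires CUDA):

import torch
from sglang.srt.layers.quantization.fp8_utils import apply_w8a8_block_fp8_linear

out_features, in_features = 2048, 4096
block_n, block_k = 128, 128
x = torch.randn(8, in_features, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(out_features, in_features, device="cuda").to(torch.float8_e4m3fn)
weight_scale = torch.rand(out_features // block_n, in_features // block_k, dtype=torch.float32, device="cuda")
y = apply_w8a8_block_fp8_linear(x, weight, [block_n, block_k], weight_scale)
print(y.shape, y.dtype)  # torch.Size([8, 2048]) torch.bfloat16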
def input_to_float8(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to float8 values with tensor-wise quantization."""
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
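For contrast with the per-group path, `input_to_float8` produces a single tensor-wide scale; a tiny sketch (assumes CUDA):

import torch
from sglang.srt.layers.quantization.fp8_utils import input_to_float8

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
x_fp8, x_inv_scale = input_to_float8(x)  # scale is returned as a reciprocal (dequant factor)
print(x_fp8.dtype, x_inv_scale.shape)  # torch.float8_e4m3fn torch.Size([])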
def block_quant_to_tensor_quant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function converts block-wise quantization to tensor-wise quantization.
The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
and the block size.
The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
Note only float8 is supported for now.
"""
block_n, block_k = block_size[0], block_size[1]
n, k = x_q_block.shape
n_tiles = (n + block_n - 1) // block_n
k_tiles = (k + block_k - 1) // block_k
assert n_tiles == x_s.shape[0]
assert k_tiles == x_s.shape[1]
x_dq_block = x_q_block.to(torch.float32)
x_dq_block_tiles = [
[
x_dq_block[
j * block_n : min((j + 1) * block_n, n),
i * block_k : min((i + 1) * block_k, k),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
for i in range(k_tiles):
for j in range(n_tiles):
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
return x_q_tensor, scale
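A quick sketch of `block_quant_to_tensor_quant`, which the DeepSeek-V2/V3 loader below uses to requantize `kv_b_proj` for the per-tensor `bmm_fp8` path (shapes are illustrative; requires CUDA):

import torch
from sglang.srt.layers.quantization.fp8_utils import block_quant_to_tensor_quant

block_size = [128, 128]
w_q = torch.randn(256, 512, device="cuda").clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
w_s = torch.rand(2, 4, dtype=torch.float32, device="cuda")  # ceil(256/128) x ceil(512/128) block scales
w_q_tensor, scale = block_quant_to_tensor_quant(w_q, w_s, block_size)
print(w_q_tensor.shape, scale.shape)  # torch.Size([256, 512]) torch.Size([])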
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
block-wise quantization. Uses both column and row parallelism.
"""
pass
@@ -43,6 +43,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant,
input_to_float8,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
@@ -186,15 +190,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0
def input_to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
class DeepseekV2Attention(nn.Module):
def __init__(
@@ -869,6 +864,16 @@ class DeepseekV2ForCausalLM(nn.Module):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
# TODO(HandH1998): Modify it when nextn is supported.
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
if num_nextn_layers > 0 and name.startswith("model.layers"):
name_list = name.split(".")
if (
len(name_list) >= 3
and int(name_list[2]) >= self.config.num_hidden_layers
):
continue
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
...@@ -933,13 +938,33 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -933,13 +938,33 @@ class DeepseekV2ForCausalLM(nn.Module):
).T ).T
else: else:
w = self_attn.kv_b_proj.weight w = self_attn.kv_b_proj.weight
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
# This may affect the accuracy of the fp8 model.
if (
hasattr(self.quant_config, "weight_block_size")
and w.dtype == torch.float8_e4m3fn
):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
w, scale = block_quant_to_tensor_quant(
w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
)
self_attn.w_scale = scale
w_kc, w_vc = w.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if (
hasattr(self_attn.kv_b_proj, "weight_scale")
and self_attn.w_scale is None
):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
EntryClass = DeepseekV2ForCausalLM
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass
EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
import itertools
import unittest
import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)
# For test
def native_per_token_group_quant_fp8(
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
):
"""Function to perform per-token-group quantization on an input tensor `x` using native torch.
It converts the tensor values into float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Note that only `torch.float8_e4m3fn` is supported for now.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
return x_q, x_s
class TestPerTokenGroupQuantFP8(unittest.TestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
GROUP_SIZE = [64, 128, 256, 512]
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _per_token_group_quant_fp8(self, num_tokens, d, dtype, group_size, seed):
torch.manual_seed(seed)
x = torch.rand(num_tokens, d, dtype=dtype)
with torch.inference_mode():
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
out, scale = per_token_group_quant_fp8(x, group_size)
self.assertTrue(
torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
)
self.assertTrue(torch.allclose(scale, ref_scale))
def test_per_token_group_quant_fp8(self):
for params in itertools.product(
self.NUM_TOKENS,
self.D,
self.DTYPES,
self.GROUP_SIZE,
self.SEEDS,
):
with self.subTest(
num_tokens=params[0],
d=params[1],
dtype=params[2],
group_size=params[3],
seed=params[4],
):
self._per_token_group_quant_fp8(*params)
# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
assert A.shape[:-1] == As.shape[:-1]
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N,)
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s
C = C.reshape(origin_C_shape).to(output_dtype)
return C
class TestW8A8BlockFP8Matmul(unittest.TestCase):
OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
M = [1, 7, 83, 512, 2048]
N = [128, 512, 1024, 4096, 7748, 13824]
K = [256, 4096, 5120, 3884, 13824]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed)
# NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
with torch.inference_mode():
ref_out = native_w8a8_block_fp8_matmul(
A_fp8, B_fp8, As, Bs, block_size, out_dtype
)
out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
self.assertTrue(
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
/ torch.mean(torch.abs(ref_out.to(torch.float32)))
< 0.001
)
def test_w8a8_block_fp8_matmul(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.BLOCK_SIZE,
self.OUT_DTYPES,
self.SEEDS,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
block_size=params[3],
out_dtype=params[4],
seed=params[5],
):
self._w8a8_block_fp8_matmul(*params)
# For test
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
"""This function performs fused moe with block-wise quantization using native torch."""
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
# NOTE(HandH1998): Since "index_cuda" is not implemented for 'Float8_e4m3fn', we need to cast `float8` to `float32`.
a_q = a_q.to(torch.float32)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_fp8_matmul(
a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_fp8_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
class TestW8A8BlockFP8FusedMoE(unittest.TestCase):
DTYPES = [torch.float32, torch.half, torch.bfloat16]
M = [1, 33, 64, 222, 1024 * 128]
N = [128, 1024, 2048]
K = [256, 4096, 5120]
E = [8, 24]
TOP_KS = [2, 6]
BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
# BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _w8a8_block_fp8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed):
torch.manual_seed(seed)
# NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a = torch.randn((M, K), dtype=dtype) / 10
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max
w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * fp8_max
w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = (2 * N + block_n - 1) // block_n
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k
w1_s = (
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
* factor_for_scale
)
w2_s = (
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
* factor_for_scale
)
score = torch.randn((M, E), dtype=dtype)
with torch.inference_mode():
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_fp8_moe(
a, w1, w2, w1_s, w2_s, score, topk, block_size
)
self.assertTrue(
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
/ torch.mean(torch.abs(ref_out.to(torch.float32)))
< 0.02
)
def test_w8a8_block_fp8_fused_moe(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.E,
self.TOP_KS,
self.BLOCK_SIZE,
self.DTYPES,
self.SEEDS,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
E=params[3],
topk=params[4],
block_size=params[5],
dtype=params[6],
seed=params[7],
):
self._w8a8_block_fp8_fused_moe(*params)
if __name__ == "__main__":
unittest.main(verbosity=2)