Unverified Commit ced3c07a authored by Cheng Wan, committed by GitHub

Support token-level quantization for EP MoE (#6782)

parent f18b068f
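
The change below switches the dynamic-quantization path of EP MoE from one activation scale shared across the whole tensor (repeated per local expert) to one scale per token. As a rough framing of what that means, here is a standalone PyTorch sketch (illustrative only, using an abs-max rule; it is not the Triton path added in this commit):

import torch

# Sketch: per-tensor vs. per-token dynamic FP8 activation scales (illustrative).
fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max

x = torch.randn(8, 16)  # [num_tokens, hidden_size]

# Per-tensor: a single scalar scale for the whole activation tensor.
per_tensor_scale = x.abs().max().to(torch.float32) / fp8_max

# Per-token: one scale per row, shape [num_tokens].
per_token_scale = x.abs().amax(dim=1).to(torch.float32) / fp8_max

# Quantize with the per-token scales and dequantize for reference.
x_q = (x / per_token_scale[:, None]).clamp(-fp8_max, fp8_max).to(fp8_dtype)
x_ref = x_q.to(torch.float32) * per_token_scale[:, None]
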
@@ -178,6 +178,7 @@ def pre_reorder_triton_kernel(
     topk,
     hidden_size,
     BLOCK_SIZE: tl.constexpr,
+    use_per_token_if_dynamic: tl.constexpr,
 ):
     OutDtype = gateup_input_ptr.dtype.element_ty
@@ -188,11 +189,15 @@ def pre_reorder_triton_kernel(
     vec = tl.arange(0, BLOCK_SIZE)
 
+    if a1_scales_ptr is not None and use_per_token_if_dynamic:
+        scale = 1.0 / tl.load(a1_scales_ptr + src_idx)
+
     for idx in range(topk):
         expert_id = tl.load(topk_ids_ptr + idx)
         if expert_id >= start_expert_id and expert_id <= end_expert_id:
             if a1_scales_ptr is not None:
-                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
+                if not use_per_token_if_dynamic:
+                    scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
             else:
                 scale = 1.0
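
With per-token scales, the kernel above reads the scale once per source token (indexed by src_idx) before the top-k loop; the per-expert lookup inside the loop is kept for the old behaviour. A plain-Python rendering of that selection logic (names mirror the kernel arguments; this is not the Triton code itself):

# Illustrative reference for the scale-selection logic above, in plain Python.
def pick_scales(a1_scales, src_idx, topk_ids, start_expert_id, end_expert_id,
                use_per_token_if_dynamic):
    scales = []
    for expert_id in topk_ids:  # one entry per top-k expert of this token
        if not (start_expert_id <= expert_id <= end_expert_id):
            continue
        if a1_scales is None:
            scales.append(1.0)
        elif use_per_token_if_dynamic:
            scales.append(1.0 / a1_scales[src_idx])  # one scale per token
        else:
            scales.append(1.0 / a1_scales[expert_id - start_expert_id])  # one scale per local expert
    return scales
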
@@ -558,6 +563,7 @@ def grouped_gemm_triton_kernel(
     bs_stride_0: tl.constexpr,
     bs_stride_2: tl.constexpr,
     bs_stride_1: tl.constexpr,
+    use_per_token_if_dynamic: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
@@ -621,7 +627,10 @@ def grouped_gemm_triton_kernel(
         b_ptr += BLOCK_SIZE_K
 
     if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
-        scale_a_value = tl.load(scale_a + m_range_start + offs_am[:, None])
+        if use_per_token_if_dynamic:
+            scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None]))
+        else:
+            scale_a_value = tl.load(scale_a + expert_id)
         scale_b_value = tl.load(scale_b + expert_id)
         accumulator *= scale_a_value * scale_b_value
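
In the GEMM epilogue, the activation scale is now taken per output row (token) when use_per_token_if_dynamic is set, while the weight scale stays per expert. The equivalent math in plain PyTorch, with made-up shapes (illustrative, not the kernel):

import torch

# Illustrative epilogue math for one expert's tile.
M, N = 4, 8
accumulator = torch.randn(M, N)          # GEMM accumulator for this tile
scale_b_expert = torch.tensor(0.02)      # one weight scale for this expert

use_per_token_if_dynamic = True
if use_per_token_if_dynamic:
    scale_a = torch.rand(M)              # one activation scale per row/token
    out = accumulator * scale_a[:, None] * scale_b_expert
else:
    scale_a_expert = torch.tensor(0.01)  # one activation scale per expert
    out = accumulator * scale_a_expert * scale_b_expert
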
@@ -658,6 +667,7 @@ def grouped_gemm_triton(
     scale_b: torch.Tensor = None,
     block_shape: Optional[List[int]] = None,
     c_dtype=None,
+    use_per_token_if_dynamic: bool = True,
 ):
     assert weight_column_major == True  # TODO: more
     if use_fp8_w8a8 and block_shape is None:
@@ -698,6 +708,11 @@ def grouped_gemm_triton(
         triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
     )
 
+    if use_fp8_w8a8 and block_shape is None and use_per_token_if_dynamic:
+        assert (
+            scale_a.shape[0] == a.shape[0]
+        ), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}"
+
     grouped_gemm_triton_kernel[grid](
         a,
         b,
@@ -721,6 +736,7 @@ def grouped_gemm_triton(
         scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
         scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
         scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
+        use_per_token_if_dynamic,
         **config,
     )
     return c
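
The new assert documents the expected scale_a layout for the per-token path: a 1-D tensor with one entry per row of a (the reordered tokens), whereas the per-expert path indexes scale_a by local expert id inside the kernel. A small shape sketch with placeholder values (illustrative only):

import torch

num_local_experts = 4
m_total = 128  # rows of the reordered activation tensor `a`
a = torch.empty(m_total, 512, dtype=torch.float8_e4m3fn)

# Per-token dynamic path: one scale per row of `a`, checked by the new assert.
scale_a_per_token = torch.empty(m_total, dtype=torch.float32)
assert scale_a_per_token.shape[0] == a.shape[0]

# Per-expert path: one scale per local expert, indexed by expert_id in the kernel.
scale_a_per_expert = torch.empty(num_local_experts, dtype=torch.float32)
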
...
@@ -50,7 +50,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+from sglang.srt.layers.quantization.fp8_kernel import (
+    scaled_fp8_quant,
+    sglang_per_token_quant_fp8,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
@@ -65,10 +68,16 @@ logger = logging.getLogger(__name__)
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
 
-    def __init__(self, device, use_flashinfer: bool = False):
+    def __init__(
+        self,
+        device,
+        use_flashinfer: bool = False,
+        use_per_token_if_dynamic: bool = True,
+    ):
         super().__init__()
         self.device = device
         self.use_flashinfer = use_flashinfer
+        self.use_per_token_if_dynamic = use_per_token_if_dynamic
         if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
             GroupedGemmRunner._init_flashinfer_wrapper(device)
@@ -124,6 +133,7 @@ class GroupedGemmRunner(torch.nn.Module):
                 scale_b,
                 block_shape=block_shape,
                 c_dtype=c_dtype,
+                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
             )
         return c
@@ -154,6 +164,7 @@ class EPMoE(torch.nn.Module):
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
+        use_per_token_if_dynamic: bool = True,
     ):
         super().__init__()
@@ -184,6 +195,7 @@ class EPMoE(torch.nn.Module):
         self.custom_routing_function = custom_routing_function
         self.activation = activation
         self.routed_scaling_factor = routed_scaling_factor
+        self.use_per_token_if_dynamic = use_per_token_if_dynamic
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -227,6 +239,7 @@ class EPMoE(torch.nn.Module):
             self.grouped_gemm_runner = GroupedGemmRunner(
                 hidden_states.device,
                 use_flashinfer=False,  # TODO: use flashinfer
+                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
             )
 
         topk_weights, topk_ids = select_experts(
@@ -259,12 +272,16 @@ class EPMoE(torch.nn.Module):
             ),
         )
         if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+            if self.use_per_token_if_dynamic:
+                max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
+                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+            else:
+                max_value = (
+                    torch.max(hidden_states)
+                    .repeat(self.num_experts_per_partition)
+                    .to(torch.float32)
+                )
+                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
 
         # PreReorder
         pre_reorder_triton_kernel[(hidden_states.shape[0],)](
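
In the hunk above, the dynamic w13 activation scale becomes a vector with one entry per input token (row-wise max over the hidden dimension), instead of a single global max repeated per local expert. Side by side in plain PyTorch (illustrative shapes, not the module's code path):

import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max
hidden_states = torch.rand(6, 32)   # [num_tokens, hidden_size]
num_experts_per_partition = 4

# Per-token path: shape [num_tokens].
w13_scale_per_token = torch.max(hidden_states, dim=1).values.to(torch.float32) / fp8_max

# Previous behaviour: one global max repeated per local expert, shape [num_experts_per_partition].
w13_scale_per_expert = (
    torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32) / fp8_max
)
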
@@ -278,9 +295,27 @@
             self.top_k,
             hidden_states.shape[1],
             BLOCK_SIZE=512,
+            use_per_token_if_dynamic=self.use_per_token_if_dynamic,
         )
         dispose_tensor(hidden_states)
 
+        if (
+            self.activation_scheme == "dynamic"
+            and not self.use_block_quant
+            and self.use_per_token_if_dynamic
+        ):
+            scale = torch.empty(
+                hidden_states_shape[0] * self.top_k,
+                device=hidden_states_device,
+                dtype=torch.float32,
+            )
+            scale[src2dst] = (
+                self.w13_input_scale.unsqueeze(1)
+                .expand(hidden_states_shape[0], self.top_k)
+                .reshape(-1)
+            )
+            self.w13_input_scale = scale
+
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
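
pre_reorder_triton_kernel permutes tokens into expert-grouped order, and each token occupies one row per selected expert. The block added above therefore expands the per-token scales to top_k copies and scatters them through src2dst, so that row i of the reordered activations lines up with entry i of w13_input_scale. A toy illustration (the src2dst values here are made up):

import torch

num_tokens, top_k = 3, 2
w13_input_scale = torch.tensor([0.10, 0.20, 0.30])  # one scale per original token
src2dst = torch.tensor([4, 1, 0, 3, 5, 2])          # destination row of each (token, k) pair

scale = torch.empty(num_tokens * top_k, dtype=torch.float32)
scale[src2dst] = (
    w13_input_scale.unsqueeze(1).expand(num_tokens, top_k).reshape(-1)
)
# scale[dst] now holds the scale of the token that was moved to row `dst`.
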
@@ -310,21 +345,24 @@
         del gateup_input
 
         # Act
-        down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states_dtype
-            ),
-        )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
-                dtype=torch.float32,
-                device=hidden_states_device,
-            )
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            self.w2_input_scale = None
+            down_input = torch.empty(
+                gateup_output.shape[0],
+                gateup_output.shape[1] // 2,
+                device=gateup_output.device,
+                dtype=hidden_states_dtype,
+            )
+        else:
+            down_input = torch.empty(
+                gateup_output.shape[0],
+                gateup_output.shape[1] // 2,
+                device=gateup_output.device,
+                dtype=(
+                    self.fp8_dtype
+                    if (self.use_fp8_w8a8 and not self.use_block_quant)
+                    else hidden_states_dtype
+                ),
+            )
 
         if self.activation == "silu":
@@ -353,6 +391,16 @@
             raise ValueError(f"Unsupported activation: {self.activation=}")
         del gateup_output
 
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            if self.use_per_token_if_dynamic:
+                down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
+            else:
+                self.w2_input_scale = torch.ones(
+                    self.num_experts_per_partition,
+                    dtype=torch.float32,
+                    device=hidden_states_device,
+                )
+
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
...
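
After the activation, down_input is re-quantized for the second grouped GEMM. On the per-token path this is delegated to sglang_per_token_quant_fp8, which as used above returns the quantized tensor together with its scales; on the per-expert path the previous behaviour (a placeholder scale of ones per local expert) is kept. A rough standalone reference for what a per-token FP8 quantization step does (illustrative only; not the library kernel, and the exact shape of the scale it returns is not visible in this diff):

import torch

def per_token_quant_fp8_reference(x: torch.Tensor):
    # Rough reference for per-token FP8 quantization (illustrative only).
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
    # One scale per row, kept in float32; clamp guards against all-zero rows.
    scale = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / fp8_max
    x_q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return x_q, scale

down_input = torch.randn(16, 64, dtype=torch.bfloat16)
down_input_q, w2_input_scale = per_token_quant_fp8_reference(down_input)
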