Commit c0352f4a authored by lizhigong's avatar lizhigong Committed by maxiao1
Browse files

adapt w4a8 marlin deepep dp ep

parent 32b1ccaf
......@@ -3,6 +3,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch
from sglang.srt import single_batch_overlap
......@@ -54,7 +55,286 @@ if _use_aiter:
logger = logging.getLogger(__name__)
class DeepEPMoE(FusedMoE):
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
temp = x.to(torch.float32).view(torch.int32)
exp = torch.bitwise_right_shift(temp, 23)
mant = torch.bitwise_and(temp, 0x7FFFFF)
is_ru = torch.logical_and(
torch.logical_and((mant > 0), (exp != 0xFE)),
~torch.logical_and((exp == 0), (mant <= 0x400000)),
)
exp = torch.where(is_ru, exp + 1, exp)
new_x = exp.to(torch.uint8).view(torch.int)
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
class EPMoE(FusedMoE):
"""
MoE Expert Parallel Impl
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
with_bias: bool = False,
):
super().__init__(
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_fused_shared_experts=num_fused_shared_experts,
layer_id=layer_id,
top_k=top_k,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
with_bias=with_bias,
)
self.intermediate_size = intermediate_size
if isinstance(quant_config, Fp8Config):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
self.use_w4a8_marlin = False
elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.use_fp8_w8a8 = False
self.activation_scheme = None
self.use_w4a8_marlin = True
else:
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
self.use_w4a8_marlin = False
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm(hidden_states, topk_output)
else:
return super().forward(hidden_states, topk_output)
def forward_deepgemm(
self,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
):
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
topk_weights, topk_ids, _ = topk_output
if not self.use_block_quant:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
scale_block_size = 128
w13_weight_scale_n = 2 * (
(self.intermediate_size + scale_block_size - 1) // scale_block_size
)
w13_weight_scale_k = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w13_weight_scale = (
self.w13_weight_scale.unsqueeze(1)
.repeat_interleave(w13_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w13_weight_scale_k, dim=2)
)
self.w13_weight_fp8 = (
self.w13_weight,
w13_weight_scale,
)
w2_weight_scale_n = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w2_weight_scale_k = (
self.intermediate_size + scale_block_size - 1
) // scale_block_size
w2_weight_scale = (
self.w2_weight_scale.unsqueeze(1)
.repeat_interleave(w2_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w2_weight_scale_k, dim=2)
)
self.w2_weight_fp8 = (
self.w2_weight,
w2_weight_scale,
)
# PreReorder
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
moe_ep_deepgemm_preprocess(
topk_ids,
self.num_experts,
hidden_states,
self.top_k,
self.start_expert_id,
self.end_expert_id,
self.block_shape,
)
)
dispose_tensor(hidden_states)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
b, s_mn, s_k = gateup_input_scale.shape
assert (
s_mn % 4 == 0 and s_k % 4 == 0
), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
# GroupGemm-0
gateup_input_fp8 = (
gateup_input,
(
_cast_to_e8m0_with_rounding_up(gateup_input_scale)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
gateup_input_scale
)
),
)
num_groups, m, k = gateup_input_fp8[0].size()
n = self.w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
gateup_input_fp8,
self.w13_weight_fp8,
gateup_output,
masked_m,
expected_m,
)
del gateup_input
del gateup_input_fp8
# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=hidden_states_device,
dtype=self.fp8_dtype,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=hidden_states_device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del gateup_output
# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
(
down_input_scale
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
),
)
down_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
down_input_fp8,
self.w2_weight_fp8,
down_output,
masked_m,
expected_m,
)
del down_input
del down_input_fp8
# PostReorder
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states_shape[1],
m_max * self.start_expert_id,
BLOCK_SIZE=512,
)
if self.moe_runner_config.routed_scaling_factor is not None:
output *= self.moe_runner_config.routed_scaling_factor
return output
class DeepEPMoE(EPMoE):
"""
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
Mooncake EP shares the same class, as they expose the same interface.
......@@ -106,11 +386,28 @@ class DeepEPMoE(FusedMoE):
self.deepep_mode = get_deepep_mode()
if self.deepep_mode.enable_low_latency() and not _is_npu:
# NPU supports low_latency deepep without deepgemm
assert (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
# TODO: move to the beginning of the file
from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=self.num_experts,
num_local_experts=self.num_local_experts,
hidden_size=hidden_size,
params_dtype=params_dtype,
deepep_mode=self.deepep_mode,
async_finish=True, # TODO
return_recv_hook=True,
)
# if self.deepep_mode.enable_low_latency() and not _is_npu:
# # NPU supports low_latency deepep without deepgemm
# assert (
# deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
# ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
if _use_aiter:
# expert_mask is of size (self.num_local_experts + 1),
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
......@@ -124,23 +421,23 @@ class DeepEPMoE(FusedMoE):
)
# the last one is invalid rank_id
self.expert_mask[:-1] = 1
elif not _is_npu:
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant or self.use_w4afp8
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
(
self.w2_weight_scale_inv
if self.use_block_quant or self.use_w4afp8
else self.w2_weight_scale
),
)
# elif not _is_npu:
# self.w13_weight_fp8 = (
# self.w13_weight,
# (
# self.w13_weight_scale_inv
# if self.use_block_quant
# else self.w13_weight_scale
# ),
# )
# self.w2_weight_fp8 = (
# self.w2_weight,
# (
# self.w2_weight_scale_inv
# if self.use_block_quant
# else self.w2_weight_scale
# ),
# )
def forward(
self,
......@@ -187,10 +484,15 @@ class DeepEPMoE(FusedMoE):
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
return self.forward_npu(dispatch_output)
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
if self.use_w4afp8:
return self.forward_cutlass_w4afp8(dispatch_output)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output)
#assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm_contiguous(dispatch_output)
elif self.use_w4a8_marlin:
return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
else:
raise ValueError(
f"Dispatch output is not supported"
)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
if (
get_moe_runner_backend().is_flashinfer_cutedsl()
......@@ -255,6 +557,34 @@ class DeepEPMoE(FusedMoE):
expert_mask=self.expert_mask,
)
def forward_deepgemm_w4a8_marlin_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
dispatch_output
)
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
# if num_recv_tokens_per_expert is None:
return hidden_states_int8.bfloat16()
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output,
# topk_weights=topk_weights,
# topk_ids=topk_idx,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales if self.use_int8_dispatch else None
# # routed_scaling_factor=self.routed_scaling_factor,
# )
# return expert_output
def forward_deepgemm_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
......
......@@ -460,11 +460,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
overlap_args: Optional["CombineOverlapArgs"],
):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
output = hidden_states
else:
raise NotImplementedError() # triton runner was supported but it's temporarily disabled
#if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
output = hidden_states
# else:
# if hidden_states.shape[0] > 0:
# num_tokens = self.src2dst.shape[0] // self.router_topk
# output = torch.empty(
# (num_tokens, hidden_states.shape[1]),
# device=hidden_states.device,
# dtype=hidden_states.dtype,
# )
# deepep_post_reorder_triton_kernel[(num_tokens,)](
# hidden_states,
# output,
# self.src2dst,
# topk_idx,
# topk_weights,
# self.router_topk,
# hidden_states.shape[1],
# BLOCK_SIZE=512,
# )
# else:
# output = torch.zeros(
# (0, hidden_states.shape[1]),
# device=hidden_states.device,
# dtype=hidden_states.dtype,
# )
previous_event = Buffer.capture() if self.async_finish else None
return output, previous_event
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment