Commit a0fb70e9 authored by lizhigong's avatar lizhigong
Browse files

adapt w4a8 marlin deepep dp ep

parent 848c5b82
...@@ -3,6 +3,7 @@ from __future__ import annotations ...@@ -3,6 +3,7 @@ from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, List, Optional, Union from typing import TYPE_CHECKING, List, Optional, Union
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
...@@ -124,7 +125,6 @@ class EPMoE(FusedMoE): ...@@ -124,7 +125,6 @@ class EPMoE(FusedMoE):
) )
self.intermediate_size = intermediate_size self.intermediate_size = intermediate_size
if isinstance(quant_config, Fp8Config): if isinstance(quant_config, Fp8Config):
self.use_block_quant = getattr(self.quant_method, "block_quant", False) self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = ( self.block_shape = (
...@@ -135,11 +135,23 @@ class EPMoE(FusedMoE): ...@@ -135,11 +135,23 @@ class EPMoE(FusedMoE):
self.use_fp8_w8a8 = True self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme self.activation_scheme = quant_config.activation_scheme
self.use_w4a8_marlin = False
elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.use_fp8_w8a8 = False
self.activation_scheme = None
self.use_w4a8_marlin = True
else: else:
self.use_fp8_w8a8 = False self.use_fp8_w8a8 = False
self.use_block_quant = False self.use_block_quant = False
self.block_shape = None self.block_shape = None
self.activation_scheme = None self.activation_scheme = None
self.use_w4a8_marlin = False
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8: if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
...@@ -386,11 +398,11 @@ class DeepEPMoE(EPMoE): ...@@ -386,11 +398,11 @@ class DeepEPMoE(EPMoE):
return_recv_hook=True, return_recv_hook=True,
) )
if self.deepep_mode.enable_low_latency() and not _is_npu: # if self.deepep_mode.enable_low_latency() and not _is_npu:
# NPU supports low_latency deepep without deepgemm # # NPU supports low_latency deepep without deepgemm
assert ( # assert (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM # deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm" # ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
if _use_aiter: if _use_aiter:
# expert_mask is of size (self.num_local_experts + 1), # expert_mask is of size (self.num_local_experts + 1),
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid) # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
...@@ -404,23 +416,23 @@ class DeepEPMoE(EPMoE): ...@@ -404,23 +416,23 @@ class DeepEPMoE(EPMoE):
) )
# the last one is invalid rank_id # the last one is invalid rank_id
self.expert_mask[:-1] = 1 self.expert_mask[:-1] = 1
elif not _is_npu: # elif not _is_npu:
self.w13_weight_fp8 = ( # self.w13_weight_fp8 = (
self.w13_weight, # self.w13_weight,
( # (
self.w13_weight_scale_inv # self.w13_weight_scale_inv
if self.use_block_quant # if self.use_block_quant
else self.w13_weight_scale # else self.w13_weight_scale
), # ),
) # )
self.w2_weight_fp8 = ( # self.w2_weight_fp8 = (
self.w2_weight, # self.w2_weight,
( # (
self.w2_weight_scale_inv # self.w2_weight_scale_inv
if self.use_block_quant # if self.use_block_quant
else self.w2_weight_scale # else self.w2_weight_scale
), # ),
) # )
def forward( def forward(
self, self,
...@@ -466,8 +478,15 @@ class DeepEPMoE(EPMoE): ...@@ -466,8 +478,15 @@ class DeepEPMoE(EPMoE):
assert DispatchOutputChecker.format_is_deepep(dispatch_output) assert DispatchOutputChecker.format_is_deepep(dispatch_output)
return self.forward_npu(dispatch_output) return self.forward_npu(dispatch_output)
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output): if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8 #assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output) if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm_contiguous(dispatch_output)
elif self.use_w4a8_marlin:
return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
else:
raise ValueError(
f"Dispatch output is not supported"
)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output): elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
if get_moe_runner_backend().is_flashinfer_cutedsl(): if get_moe_runner_backend().is_flashinfer_cutedsl():
return self.forward_flashinfer_cutedsl(dispatch_output) return self.forward_flashinfer_cutedsl(dispatch_output)
...@@ -526,6 +545,34 @@ class DeepEPMoE(EPMoE): ...@@ -526,6 +545,34 @@ class DeepEPMoE(EPMoE):
expert_mask=self.expert_mask, expert_mask=self.expert_mask,
) )
def forward_deepgemm_w4a8_marlin_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
dispatch_output
)
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
# if num_recv_tokens_per_expert is None:
return hidden_states_int8.bfloat16()
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output,
# topk_weights=topk_weights,
# topk_ids=topk_idx,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales if self.use_int8_dispatch else None
# # routed_scaling_factor=self.routed_scaling_factor,
# )
# return expert_output
def forward_deepgemm_contiguous( def forward_deepgemm_contiguous(
self, self,
dispatch_output: DeepEPNormalOutput, dispatch_output: DeepEPNormalOutput,
......
...@@ -431,32 +431,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): ...@@ -431,32 +431,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
deepep_post_reorder_triton_kernel, deepep_post_reorder_triton_kernel,
) )
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu: #if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
output = hidden_states output = hidden_states
else: # else:
if hidden_states.shape[0] > 0: # if hidden_states.shape[0] > 0:
num_tokens = self.src2dst.shape[0] // self.router_topk # num_tokens = self.src2dst.shape[0] // self.router_topk
output = torch.empty( # output = torch.empty(
(num_tokens, hidden_states.shape[1]), # (num_tokens, hidden_states.shape[1]),
device=hidden_states.device, # device=hidden_states.device,
dtype=hidden_states.dtype, # dtype=hidden_states.dtype,
) # )
deepep_post_reorder_triton_kernel[(num_tokens,)]( # deepep_post_reorder_triton_kernel[(num_tokens,)](
hidden_states, # hidden_states,
output, # output,
self.src2dst, # self.src2dst,
topk_idx, # topk_idx,
topk_weights, # topk_weights,
self.router_topk, # self.router_topk,
hidden_states.shape[1], # hidden_states.shape[1],
BLOCK_SIZE=512, # BLOCK_SIZE=512,
) # )
else: # else:
output = torch.zeros( # output = torch.zeros(
(0, hidden_states.shape[1]), # (0, hidden_states.shape[1]),
device=hidden_states.device, # device=hidden_states.device,
dtype=hidden_states.dtype, # dtype=hidden_states.dtype,
) # )
previous_event = Buffer.capture() if self.async_finish else None previous_event = Buffer.capture() if self.async_finish else None
return output, previous_event return output, previous_event
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment