Commit bd63af06 authored by maxiao1

Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'

Adapt w8a8_marlin for the high-throughput (DeepEP) mode
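For reviewers, a minimal pure-PyTorch sketch of what the new masked grouped w8a8 path computes per expert group. This is an illustrative emulation, not the lightop kernel: the scale layouts and the accumulation dtype are assumptions; only the call pattern `(activations, scales), (weights, scales), out, masked_m, expected_m` is taken from the `m_grouped_w8a8_gemm_nt_masked` usage in this MR.

```python
import torch


def masked_grouped_w8a8_gemm_ref(a_q, a_scale, w_q, w_scale, out, masked_m):
    """Reference emulation with assumed layouts (not the lightop kernel).

    a_q:     [groups, m, k]  int8 per-token-quantized activations
    a_scale: [groups, m, 1]  per-token activation scales (assumed layout)
    w_q:     [groups, n, k]  int8 weights in "nt" (row-major) layout
    w_scale: [groups, n]     per-output-channel weight scales (assumed layout)
    out:     [groups, m, n]  bf16 output; rows >= masked_m[g] are left untouched
    """
    for g in range(a_q.size(0)):
        rows = int(masked_m[g])  # only the first masked_m[g] tokens of group g are valid
        # float32 accumulation is exact for these toy sizes; the real kernel accumulates in int32
        acc = a_q[g, :rows].float() @ w_q[g].float().T
        out[g, :rows] = (acc * a_scale[g, :rows] * w_scale[g]).to(torch.bfloat16)
    return out


# Toy usage with made-up sizes.
groups, m, k, n = 2, 8, 32, 16
a_q = torch.randint(-128, 128, (groups, m, k), dtype=torch.int8)
a_scale = torch.rand(groups, m, 1)
w_q = torch.randint(-128, 128, (groups, n, k), dtype=torch.int8)
w_scale = torch.rand(groups, n)
out = torch.zeros(groups, m, n, dtype=torch.bfloat16)
masked_m = torch.tensor([5, 3])
masked_grouped_w8a8_gemm_ref(a_q, a_scale, w_q, w_scale, out, masked_m)
```

Also, a small sketch of the kpack2 tile reordering that the new `w8a8_nt_kpack2_marlin_weight` helper applies to int8 weights for the DeepEP path. The 16x16 tiles match the helper's defaults; the tensor sizes are made up for illustration.

```python
import torch


def kpack2_repack(w: torch.Tensor, k_tile: int = 16, n_tile: int = 16) -> torch.Tensor:
    # w is laid out as [size_n, size_k]; split both dims into 16x16 tiles,
    # make each (n_tile, k_tile) tile contiguous, then flatten back to 2D.
    size_n, size_k = w.shape
    q = w.reshape(size_n // n_tile, n_tile, size_k // k_tile, k_tile)
    q = q.permute(0, 2, 1, 3).contiguous()
    return q.reshape(size_n // k_tile, size_k * k_tile)


w = torch.randint(-128, 128, (128, 64), dtype=torch.int8)  # assumed toy shape
packed = kpack2_repack(w)
print(packed.shape)  # torch.Size([8, 1024]) for the toy shape above
```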

See merge request OpenDAS/sglang!25
parents 92f82dce eed591c9
@@ -3,6 +3,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch
import torch.distributed as dist
@@ -39,7 +40,7 @@ if TYPE_CHECKING:
    DeepEPNormalOutput,
    DispatchOutput,
)
-from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep
+from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
_is_hip = is_hip()
@@ -127,6 +128,7 @@ class EPMoE(FusedMoE):
            self.fp8_dtype = torch.float8_e4m3fn
            self.activation_scheme = quant_config.activation_scheme
            self.use_w4a8_marlin = False
            self.use_w8a8_marlin = False
        elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
@@ -137,12 +139,25 @@ class EPMoE(FusedMoE):
            self.use_fp8_w8a8 = False
            self.activation_scheme = None
            self.use_w4a8_marlin = True
            self.use_w8a8_marlin = False
        elif isinstance(quant_config, SlimQuantCompressedTensorsMarlinConfig):
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
                self.quant_method.quant_config.weight_block_size
                if self.use_block_quant
                else None
            )
            self.use_fp8_w8a8 = False
            self.activation_scheme = None
            self.use_w4a8_marlin = False
            self.use_w8a8_marlin = True
        else:
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
            self.block_shape = None
            self.activation_scheme = None
            self.use_w4a8_marlin = False
            self.use_w8a8_marlin = False

    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
@@ -498,6 +513,8 @@ class DeepEPMoE(EPMoE):
        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
            if self.use_w4a8_marlin:
                return self.forward_groupgemm_w4a8_marlin_masked(dispatch_output)
            elif self.use_w8a8_marlin:
                return self.forward_groupgemm_w8a8_marlin_masked(dispatch_output)
            else:
                if (
                    get_moe_runner_backend().is_flashinfer_cutedsl()
@@ -783,7 +800,7 @@ class DeepEPMoE(EPMoE):
        # base shapes
        num_groups, m, k = hidden_states.size()
-        expected_m = m // 2  # shape required by the kernel
+        expected_m = min(m, expected_m)

        # ---- first quant: ensure float input for quantizer ----
        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
@@ -822,6 +839,56 @@ class DeepEPMoE(EPMoE):
        return down_output

    def forward_groupgemm_w8a8_marlin_masked(
        self,
        dispatch_output: DeepEPLLOutput,
    ):
        hidden_states, _, _, _, masked_m, expected_m = dispatch_output
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"

        # base shapes
        num_groups, m, k = hidden_states.size()
        expected_m = min(m, expected_m)

        # ---- first quant: ensure float input for quantizer ----
        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)

        # ---- weights & scales ----
        w13_weight = self.w13_weight
        w13_scales = self.w13_weight_scale
        w2_weight = self.w2_weight
        w2_scales = self.w2_weight_scale

        n1 = w13_scales.size(1)
        gateup_output = torch.empty((num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16)

        # ---- first GEMM ----
        m_grouped_w8a8_gemm_nt_masked(
            (q_a1_all, q_a1_scale),
            (w13_weight, w13_scales),
            gateup_output,
            masked_m,
            expected_m,
        )

        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)

        # ---- second GEMM ----
        n2 = w2_scales.size(1)
        down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
        m_grouped_w8a8_gemm_nt_masked(
            (q_a2_all, q_a2_scale),
            (w2_weight, w2_scales),
            down_output,
            masked_m,
            expected_m,
        )
        return down_output

    def forward_deepgemm_masked(
        self,
        dispatch_output: DeepEPLLOutput,
...
@@ -39,6 +39,18 @@ def get_w8a8_int8_marlin_weights(
    return weight


def w8a8_nt_kpack2_marlin_weight(w8a8_w,  # [size_n, size_k // 2]
                                 k_tile=16,
                                 n_tile=16):
    assert w8a8_w.dtype == torch.int8, "w8a8_w must be of dtype int8"
    size_n, size_k = w8a8_w.shape
    assert size_n % k_tile == 0 and size_k % n_tile == 0, \
        "k_tile / n_tile must evenly divide the corresponding dimensions"
    q = w8a8_w.reshape((size_n // n_tile, n_tile, size_k // k_tile, k_tile))
    q = q.permute((0, 2, 1, 3)).contiguous()
    q = q.reshape((size_n // k_tile, size_k * k_tile))
    return q


class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):

    @staticmethod
    def get_moe_method(
@@ -65,7 +77,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
            "weights")
        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
            "input_activations")
        self.use_deepep = True

        per_channel = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
@@ -138,13 +150,19 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        w1_marlin_list = []
        for ii in range(layer.w13_weight.shape[0]):
-            w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
+            if not self.use_deepep:
+                w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
+            else:
+                w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
            w1_marlin_list.append(w1_marlin_in)
        w1_marlin = torch.stack(w1_marlin_list, dim=0)

        w2_marlin_list = []
        for ii in range(layer.w2_weight.shape[0]):
-            w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
+            if not self.use_deepep:
+                w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
+            else:
+                w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
            w2_marlin_list.append(w2_marlin_in)
        w2_marlin = torch.stack(w2_marlin_list, dim=0)
...