"examples/pytorch/vscode:/vscode.git/clone" did not exist on "6bc8216118b64b77e042fc67757871d485ad5a72"
Unverified commit 93cec433, authored by fzyzcjy, committed by GitHub

Support new DeepGEMM (#7172)

parent ba589b88
@@ -1231,6 +1231,7 @@ class DeepEPMoE(EPMoE):
down_input_scale,
scale_block_size,
masked_m,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del gateup_output
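Note on the new kwarg: scale_ue8m0 asks the grouped quantization to emit UE8M0 scales, i.e. each scale is stored as a bare 8-bit biased exponent, so only power-of-two scale values are representable. A minimal sketch of the idea (hypothetical helper, not the actual kernel behind this call):

import torch

def quant_fp8_pow2(x: torch.Tensor, group: int = 128):
    # Round each group's scale UP to a power of two so quantized values still
    # fit in FP8 range; a power-of-two scale is exactly representable as a
    # single E8M0 exponent byte.
    g = x.view(-1, group)
    amax = g.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-4)
    scale = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))  # 448 = FP8 E4M3 max
    return (g / scale).to(torch.float8_e4m3fn), scale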
@@ -1238,7 +1239,13 @@ class DeepEPMoE(EPMoE):
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
(
down_input_scale
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
down_input_scale
)
),
)
down_output = torch.empty(
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
...
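The branch above exists because the two DeepGEMM generations want scales in different layouts: the old API consumes FP32 scales pre-arranged into a column-major, TMA-aligned tensor, while the 2025-06 API takes the packed UE8M0 scales directly. A sketch of the selection, using only names already present in this diff:

import torch
from sglang.srt.layers.quantization import deep_gemm_wrapper

def prepare_scale(scale: torch.Tensor) -> torch.Tensor:
    # New DeepGEMM handles the layout of packed UE8M0 scales itself; the old
    # kernels require the column-major, TMA-aligned copy.
    if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
        return scale
    return deep_gemm_wrapper.get_col_major_tma_aligned_tensor(scale)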
@@ -584,6 +584,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
use_fp8=use_fp8,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
round_scale=deep_gemm_wrapper.DEEPGEMM_V202506,
use_ue8m0=deep_gemm_wrapper.DEEPGEMM_V202506,
)
)
return packed_recv_hidden, packed_recv_count, event, hook
...
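With the 2025-06 DeepGEMM, the low-latency dispatch itself is asked to produce scales in the form the kernels expect: round_scale rounds each scale to a power of two and use_ue8m0 packs the exponents. A sketch of the gating, assuming (as the hunk above shows) that the dispatcher accepts these keyword arguments:

from sglang.srt.layers.quantization import deep_gemm_wrapper

def low_latency_dispatch_kwargs(use_fp8: bool, return_recv_hook: bool) -> dict:
    # Both flags must stay in lockstep with the DeepGEMM generation: old
    # kernels would misread UE8M0-packed scales as FP32.
    kwargs = dict(
        use_fp8=use_fp8,
        async_finish=not return_recv_hook,
        return_recv_hook=return_recv_hook,
    )
    if deep_gemm_wrapper.DEEPGEMM_V202506:
        kwargs.update(round_scale=True, use_ue8m0=True)
    return kwargs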
@@ -12,6 +12,7 @@ import torch
import triton
import triton.language as tl
from sglang.math_utils import ceil_div
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
@@ -518,10 +519,6 @@ def fused_moe_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)
def ceil_div(a, b):
return (a + b - 1) // b
@triton.jit
def moe_align_block_size_stage1(
topk_ids_ptr,
...
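The import of ceil_div from sglang.math_utils above replaces the local copy deleted here; both compute the same ceiling division:

def ceil_div(a: int, b: int) -> int:
    # Semantics of sglang.math_utils.ceil_div, shown for reference.
    return (a + b - 1) // b

assert ceil_div(10, 4) == 3 and ceil_div(8, 4) == 2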
@@ -21,6 +21,12 @@ def _compute_enable_deep_gemm():
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
DEEPGEMM_V202506 = False
try:
from deep_gemm import fp8_gemm_nt
# They have not given a name to this breaking change
DEEPGEMM_V202506 = True
except ImportError:
DEEPGEMM_V202506 = False
DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506
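Since DEEPGEMM_SCALE_UE8M0 is simply aliased to the version flag, every scale produced for the new kernels is a power of two and can be stored as one exponent byte, four to an int32. A hypothetical packer illustrating the format (byte order and packing axis are assumptions, not taken from DeepGEMM):

import torch

def pack_scales_ue8m0(s: torch.Tensor) -> torch.Tensor:
    # Assumes s is FP32 and already power-of-two valued: the IEEE-754 biased
    # exponent of 2**(e - 127) is exactly e, so keep that byte and fold four
    # consecutive bytes along the last dim into one int32.
    e = (s.view(torch.int32) >> 23).to(torch.uint8)
    assert e.shape[-1] % 4 == 0
    return e.contiguous().view(torch.int32)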
@@ -16,14 +16,24 @@ logger = logging.getLogger(__name__)
if ENABLE_JIT_DEEPGEMM:
import deep_gemm
from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
from deep_gemm import get_col_major_tma_aligned_tensor
from deep_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
)
from deep_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
)
if DEEPGEMM_V202506:
from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
from deep_gemm import (
fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
)
from deep_gemm import (
m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
)
else:
from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
from deep_gemm import get_col_major_tma_aligned_tensor
from deep_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
)
from deep_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
)
def grouped_gemm_nt_f8f8bf16_masked(
...
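The conditional imports bind both API generations to the same internal aliases (_gemm_nt_f8f8bf16_raw and friends), so wrappers like grouped_gemm_nt_f8f8bf16_masked above stay version-agnostic. The pattern in miniature, using only symbols named in this hunk:

try:
    # 2025-06 API
    from deep_gemm import fp8_gemm_nt as _gemm_nt_raw
except ImportError:
    # pre-2025-06 API
    from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_raw

def gemm_nt_f8f8bf16(*args, **kwargs):
    # Call sites never spell the underlying DeepGEMM symbol name.
    return _gemm_nt_raw(*args, **kwargs)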
@@ -765,7 +765,15 @@ def prepare_block_fp8_matmul_inputs(
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1]
assert A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
if As.dtype == torch.float:
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
elif As.dtype == torch.int:
assert (
triton.cdiv(triton.cdiv(A.shape[-1], block_k), 4) == As.shape[-1]
), f"{A.shape=} {As.shape=} {block_size=}"
else:
raise NotImplementedError
M = A.numel() // A.shape[-1]
@@ -773,8 +781,17 @@ def prepare_block_fp8_matmul_inputs(
assert B.is_contiguous()
assert Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
if Bs.dtype == torch.float:
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
elif Bs.dtype == torch.int:
assert N == Bs.shape[0], f"{B.shape=} {Bs.shape=} {block_size=}"
assert (
triton.cdiv(triton.cdiv(K, block_k), 4) == Bs.shape[1]
), f"{B.shape=} {Bs.shape=} {block_size=}"
else:
raise NotImplementedError
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
...
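The dtype-dispatched assertions encode the packed-scale shapes: FP32 scales keep one value per group, while int32 UE8M0 scales hold four exponent bytes per element, and per the asserts above the weight scales get one row per output channel. Worked shape arithmetic for a typical block size:

import triton

M, K, N = 16, 7168, 4096
block_n = block_k = 128
k_groups = triton.cdiv(K, block_k)              # 56 scale groups along K
as_cols_fp32 = k_groups                         # float path: one FP32 per group
as_cols_ue8m0 = triton.cdiv(k_groups, 4)        # int path: 14 int32 words per row
bs_shape_ue8m0 = (N, triton.cdiv(k_groups, 4))  # per the N == Bs.shape[0] assert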
@@ -238,6 +238,7 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
block_size[1],
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
...
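The SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0 branch (body collapsed above) guards an optional consistency check. One plausible form of such a check is decoding the packed exponents back to FP32 scales; the layout mirrors the hypothetical packer earlier and is likewise an assumption:

import torch

def unpack_scales_ue8m0(packed: torch.Tensor) -> torch.Tensor:
    # Reinterpret each int32 as four exponent bytes e and decode 2**(e - 127).
    e = packed.contiguous().view(torch.uint8).to(torch.float32)
    return torch.exp2(e - 127.0)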
@@ -51,7 +51,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization import deep_gemm_wrapper
@@ -1932,7 +1932,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
self_attn.use_deep_gemm_bmm = True
if False: # TODO (pr-chain)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
self._weight_requant_ue8m0()
def _weight_requant_ue8m0(self):
...
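Enabling this path means checkpoints quantized with FP32 block scales must be requantized before the UE8M0 kernels can consume them. A hypothetical sketch of the core step of a _weight_requant_ue8m0-style pass (the real method must also handle the (128, 128) block tiling and fused modules):

import torch

def requant_weight_pow2(w_fp8: torch.Tensor, s: torch.Tensor):
    # Dequantize with the original scales, round each scale UP to a power of
    # two (never down, so no block overflows FP8 range), then requantize.
    # Assumes s is already broadcast to w_fp8's shape.
    w = w_fp8.to(torch.float32) * s
    s_pow2 = torch.exp2(torch.ceil(torch.log2(s)))
    return (w / s_pow2).to(torch.float8_e4m3fn), s_pow2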