Unverified commit 4d51588e, authored by Yifan Qiao and committed by GitHub
Browse files

[Feat] DeepSeek V4 Rebased (#40860)


Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: qizixi <zixi@inferact.ai>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <yongye@inferact.ai>
Co-authored-by: Simon Mo <simon@inferact.ai>
Co-authored-by: Bugen Zhao <i@bugenzhao.com>
Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roy Wang <yasong.wang@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Zhewen Li <jerven.vllm@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: khluu <khluu000@gmail.com>
Co-authored-by: qizixi <zixi@inferact.ai>
Co-authored-by: Zhewen Li <zhewenli@inferact.ai>
parent 32e45636
...@@ -334,6 +334,8 @@ class RoutingSimulatorRouter(BaseRouter): ...@@ -334,6 +334,8 @@ class RoutingSimulatorRouter(BaseRouter):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
indices_type: torch.dtype | None, indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Use routing simulator to compute routing.""" """Use routing simulator to compute routing."""
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
......
...@@ -72,6 +72,8 @@ class ZeroExpertRouter(BaseRouter): ...@@ -72,6 +72,8 @@ class ZeroExpertRouter(BaseRouter):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
indices_type: torch.dtype | None, indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing with full bias, compute zero expert output, """Compute routing with full bias, compute zero expert output,
mask zero expert IDs.""" mask zero expert IDs."""
......
...@@ -91,6 +91,7 @@ def _moe_forward( ...@@ -91,6 +91,7 @@ def _moe_forward(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None,
layer_name: _layer_name_type, layer_name: _layer_name_type,
) -> torch.Tensor: ) -> torch.Tensor:
layer = get_layer_from_name(_resolve_layer_name(layer_name)) layer = get_layer_from_name(_resolve_layer_name(layer_name))
...@@ -99,6 +100,7 @@ def _moe_forward( ...@@ -99,6 +100,7 @@ def _moe_forward(
hidden_states, hidden_states,
router_logits, router_logits,
shared_experts_input, shared_experts_input,
input_ids,
) )
...@@ -106,6 +108,7 @@ def _moe_forward_fake( ...@@ -106,6 +108,7 @@ def _moe_forward_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None,
layer_name: _layer_name_type, layer_name: _layer_name_type,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -115,6 +118,7 @@ def _moe_forward_shared( ...@@ -115,6 +118,7 @@ def _moe_forward_shared(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None,
layer_name: _layer_name_type, layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
layer = get_layer_from_name(_resolve_layer_name(layer_name)) layer = get_layer_from_name(_resolve_layer_name(layer_name))
...@@ -123,6 +127,7 @@ def _moe_forward_shared( ...@@ -123,6 +127,7 @@ def _moe_forward_shared(
hidden_states, hidden_states,
router_logits, router_logits,
shared_experts_input, shared_experts_input,
input_ids,
) )
...@@ -130,6 +135,7 @@ def _moe_forward_shared_fake( ...@@ -130,6 +135,7 @@ def _moe_forward_shared_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None,
layer_name: _layer_name_type, layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Output shapes: # Output shapes:
...@@ -433,6 +439,7 @@ class MoERunner(MoERunnerInterface): ...@@ -433,6 +439,7 @@ class MoERunner(MoERunnerInterface):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor]: ) -> tuple[torch.Tensor | None, torch.Tensor]:
"""Run expert routing and the fused MoE kernel via the quant method. """Run expert routing and the fused MoE kernel via the quant method.
...@@ -449,11 +456,13 @@ class MoERunner(MoERunnerInterface): ...@@ -449,11 +456,13 @@ class MoERunner(MoERunnerInterface):
layer=layer, layer=layer,
x=hidden_states, x=hidden_states,
router_logits=router_logits, router_logits=router_logits,
input_ids=input_ids,
) )
else: else:
topk_weights, topk_ids = self.router.select_experts( topk_weights, topk_ids = self.router.select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
input_ids=input_ids,
) )
# Passing shared_experts_input in case SharedExpertsOrder is # Passing shared_experts_input in case SharedExpertsOrder is
...@@ -523,6 +532,7 @@ class MoERunner(MoERunnerInterface): ...@@ -523,6 +532,7 @@ class MoERunner(MoERunnerInterface):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Invoke the fused moe layer. """Invoke the fused moe layer.
...@@ -565,6 +575,7 @@ class MoERunner(MoERunnerInterface): ...@@ -565,6 +575,7 @@ class MoERunner(MoERunnerInterface):
hidden_states, hidden_states,
router_logits, router_logits,
shared_experts_input, shared_experts_input,
input_ids,
self._encode_layer_name(), self._encode_layer_name(),
) )
...@@ -672,6 +683,7 @@ class MoERunner(MoERunnerInterface): ...@@ -672,6 +683,7 @@ class MoERunner(MoERunnerInterface):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Entry point called by the custom op to run the MoE computation. """Entry point called by the custom op to run the MoE computation.
...@@ -712,6 +724,7 @@ class MoERunner(MoERunnerInterface): ...@@ -712,6 +724,7 @@ class MoERunner(MoERunnerInterface):
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
shared_experts_input=shared_experts_input, shared_experts_input=shared_experts_input,
input_ids=input_ids,
) )
return self._maybe_combine( return self._maybe_combine(
......
...@@ -26,6 +26,7 @@ class MoERunnerInterface(ABC): ...@@ -26,6 +26,7 @@ class MoERunnerInterface(ABC):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
......
...@@ -309,6 +309,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -309,6 +309,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
if self.unquantized_backend == UnquantizedMoeBackend.CPU: if self.unquantized_backend == UnquantizedMoeBackend.CPU:
......
...@@ -4,6 +4,7 @@ import functools ...@@ -4,6 +4,7 @@ import functools
from math import prod from math import prod
import torch import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...@@ -384,3 +385,20 @@ def trtllm_moe_pack_topk_ids_weights( ...@@ -384,3 +385,20 @@ def trtllm_moe_pack_topk_ids_weights(
return (topk_ids.to(torch.int32) << 16) | topk_weights.to(torch.bfloat16).view( return (topk_ids.to(torch.int32) << 16) | topk_weights.to(torch.bfloat16).view(
torch.int16 torch.int16
) )
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def swiglu_limit_func(
    output: torch.Tensor,
    input: torch.Tensor,  # first half is gate, second half is up
    swiglu_limit: float = 0.0,
) -> None:
    """Write ``silu(gate) * up`` into ``output``, optionally clamping first.

    ``input`` is split along dim 1 at ``input.shape[1] // 2``: the leading
    columns are the gate and the trailing columns are the up projection.

    Args:
        output: Destination tensor; overwritten in place via ``copy_``.
        input: 2-D tensor holding ``[gate | up]`` along dim 1.
        swiglu_limit: When strictly positive, the gate is clamped from above
            at ``swiglu_limit`` and the up half is clamped to
            ``[-swiglu_limit, swiglu_limit]``. Non-positive values disable
            clamping.
    """
    half = input.shape[1] // 2
    gate, up = input[:, :half], input[:, half:]
    if swiglu_limit > 0:
        # The gate is only bounded from above; the up half is bounded on
        # both sides.
        gate = torch.clamp(gate, max=swiglu_limit)
        up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit)
    activated = F.silu(gate) * up
    output.copy_(activated)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from functools import cache
from typing import TYPE_CHECKING
import torch
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_tilelang
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op
# tilelang is only available on CUDA platforms
# The branch is also taken under TYPE_CHECKING so that static analysis sees
# the real `tilelang` / `T` symbols.
if TYPE_CHECKING or current_platform.is_cuda_alike():
    if not has_tilelang():
        # Fail at import time with an actionable message instead of a bare
        # ModuleNotFoundError from the imports below.
        raise ImportError(
            "tilelang is required for mhc but is not installed. Install it with "
            "`pip install tilelang`."
        )
    import tilelang
    import tilelang.language as T
else:
    # Placeholders so the names are defined on other platforms.
    # NOTE(review): the `@tilelang.jit(...)` decorators further down are
    # evaluated at module import, so with `tilelang = None` they would raise
    # AttributeError — confirm this module is never imported off-CUDA.
    tilelang = None  # type: ignore[assignment]
    T = None  # type: ignore[assignment]
@cache
def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int:
device_props = torch.cuda.get_device_properties(0)
n_sms = device_props.multi_processor_count
split_k = n_sms // grid_size
if k is not None:
# avoid split_k for small k
num_block_k = cdiv(k, block_k)
split_k = min(split_k, num_block_k // 4)
split_k = max(split_k, 1)
return split_k
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
    },
)
def mhc_pre_big_fuse_tilelang(
    gemm_out_mul,
    gemm_out_sqrsum,
    hc_scale,
    hc_base,
    residual,
    post_mix,
    comb_mix,
    layer_input,
    hidden_size: int,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 16,
    hc_mult: int = 4,
):
    """Deeply fused kernels, everything other than gemm & sqrsum in mHC pre block.

    Inputs are the split-K partial results (``gemm_out_mul``,
    ``gemm_out_sqrsum``), the scale/bias vectors (``hc_scale``, ``hc_base``)
    and ``residual``. Outputs written in place are ``post_mix``, ``comb_mix``
    and ``layer_input``. Shapes and dtypes are the ones declared by the
    ``T.Tensor`` annotations in the body.
    """
    num_tokens = T.dynamic("num_tokens")
    # Per-token mix vector layout, as indexed below:
    # [pre (hc_mult) | post (hc_mult) | comb (hc_mult * hc_mult)].
    hc_mult3 = hc_mult * (2 + hc_mult)
    # gcd guarantees the block size divides hidden_size and is at most 512.
    hidden_block = math.gcd(512, hidden_size)
    gemm_out_mul: T.Tensor[[n_splits, num_tokens, hc_mult3], T.float32]  # type: ignore[no-redef, valid-type]
    gemm_out_sqrsum: T.Tensor[[n_splits, num_tokens], T.float32]  # type: ignore[no-redef, valid-type]
    hc_scale: T.Tensor[[3], T.float32]  # type: ignore[no-redef, valid-type]
    hc_base: T.Tensor[[hc_mult3], T.float32]  # type: ignore[no-redef, valid-type]
    residual: T.Tensor[[num_tokens, hc_mult, hidden_size], T.bfloat16]  # type: ignore[no-redef, valid-type]
    # outputs
    post_mix: T.Tensor[[num_tokens, hc_mult], T.float32]  # type: ignore[no-redef, valid-type]
    comb_mix: T.Tensor[[num_tokens, hc_mult * hc_mult], T.float32]  # type: ignore[no-redef, valid-type]
    layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16]  # type: ignore[no-redef, valid-type]
    # One kernel instance per token `i`, 96 threads each.
    with T.Kernel(num_tokens, threads=96) as i:
        # NOTE(review): presumably a programmatic-dependent-launch sync point —
        # confirm against the tilelang documentation.
        T.pdl_sync()
        ##################################################################
        # _pre_norm_fn_fwd_norm
        rms = T.alloc_fragment(1, T.float32)
        mixes = T.alloc_fragment(hc_mult3, T.float32)
        T.clear(mixes)
        rms[0] = 0
        # Reduce the split-K partial squared sums for this token.
        for i_split in T.serial(n_splits):
            rms[0] += gemm_out_sqrsum[i_split, i]
        # rsqrt(mean square + eps) over the hc_mult * hidden_size elements.
        rms[0] = T.rsqrt(rms[0] / (hc_mult * hidden_size) + rms_eps)
        for j in T.Parallel(hc_mult3):
            mixes[j] = 0
            # Reduce the split-K partial GEMM outputs, then apply the norm.
            for i_split in T.serial(n_splits):
                mixes[j] += gemm_out_mul[i_split, i, j]
            mixes[j] *= rms[0]
        # Stage the mixes in a shared buffer so both branches below read them.
        mixes_shared = T.alloc_shared(hc_mult3, T.float32)
        T.copy(mixes, mixes_shared)
        # Threads 0..31 compute post_mix / comb_mix; the remaining 64 threads
        # compute pre_mix and apply it to the residual.
        if T.get_thread_binding() < 32:
            ##################################################################
            # _pre_split_mixes_fwd (post & comb)
            cm = T.alloc_fragment((hc_mult, hc_mult), T.float32)
            # post_mix = sigmoid(mix * scale[1] + base) * hc_post_mult_value
            for j in T.Parallel(hc_mult):
                post_mix[i, j] = (
                    T.sigmoid(
                        mixes_shared[j + hc_mult] * hc_scale[1] + hc_base[j + hc_mult]
                    )
                    * hc_post_mult_value
                )
            # comb logits = mix * scale[2] + base, viewed as (hc_mult, hc_mult)
            for j, k in T.Parallel(hc_mult, hc_mult):
                cm[j, k] = (
                    mixes_shared[j * hc_mult + k + hc_mult * 2] * hc_scale[2]
                    + hc_base[j * hc_mult + k + hc_mult * 2]
                )
            ##################################################################
            # _sinkhorn_fwd
            row_sum = T.alloc_fragment(hc_mult, T.float32)
            col_sum = T.alloc_fragment(hc_mult, T.float32)
            # comb = comb.softmax(-1) + eps
            row_max = T.alloc_fragment(hc_mult, T.float32)
            T.reduce_max(cm, row_max, dim=1)
            # Subtract the row max before exponentiating.
            for j, k in T.Parallel(hc_mult, hc_mult):
                cm[j, k] = T.exp(cm[j, k] - row_max[j])
            T.reduce_sum(cm, row_sum, dim=1)
            for j, k in T.Parallel(hc_mult, hc_mult):
                cm[j, k] = cm[j, k] / row_sum[j] + hc_sinkhorn_eps
            # comb = comb / (comb.sum(-2) + eps)
            T.reduce_sum(cm, col_sum, dim=0)
            for j, k in T.Parallel(hc_mult, hc_mult):
                cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps)
            # The softmax + column pass above is the first iteration; the
            # remaining sinkhorn_repeat - 1 iterations alternate row and
            # column normalization.
            for _ in T.serial(sinkhorn_repeat - 1):
                # comb = comb / (comb.sum(-1) + eps)
                T.reduce_sum(cm, row_sum, dim=1)
                for j, k in T.Parallel(hc_mult, hc_mult):
                    cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps)
                # comb = comb / (comb.sum(-2) + eps)
                T.reduce_sum(cm, col_sum, dim=0)
                for j, k in T.Parallel(hc_mult, hc_mult):
                    cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps)
            # save comb_mix to global memory
            for j, k in T.Parallel(hc_mult, hc_mult):
                comb_mix[i, j * hc_mult + k] = cm[j, k]
        else:
            ##################################################################
            # _pre_split_mixes_fwd (pre)
            pre_mix_shared = T.alloc_shared(hc_mult, T.float32)
            # pre_mix = sigmoid(mix * scale[0] + base) + hc_pre_eps
            for j in T.Parallel(hc_mult):
                pre_mix_shared[j] = (
                    T.sigmoid(
                        mixes_shared[j] * hc_scale[0] + hc_base[j],
                    )
                    + hc_pre_eps
                )
            ###################################################################
            # _pre_apply_mix_fwd
            # layer_input[i, :] = sum over hc of pre_mix[hc] * residual[i, hc, :],
            # processed hidden_block columns at a time.
            for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2):
                xs = T.alloc_shared((hc_mult, hidden_block), T.float32)
                xl = T.alloc_fragment((hc_mult, hidden_block), T.float32)
                T.copy(residual[i, 0, i0_h * hidden_block], xs)
                T.copy(xs, xl)
                ol = T.alloc_fragment(hidden_block, T.float32)
                T.clear(ol)
                for i_hc in T.serial(hc_mult):
                    pre = pre_mix_shared[i_hc]
                    for i1_h in T.Parallel(hidden_block):
                        ol[i1_h] += pre * xl[i_hc, i1_h]
                T.copy(ol, layer_input[i, i0_h * hidden_block])
        # NOTE(review): counterpart of T.pdl_sync() above — confirm semantics.
        T.pdl_trigger()
def mhc_pre(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Forward pass for mHC pre block.

    Runs the split-K prenorm GEMM (``tf32_hc_prenorm_gemm``) and then the
    fused tilelang kernel that produces the three outputs.

    Args:
        residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16
        fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32,
            where hc_mult3 = 2 * hc_mult + hc_mult * hc_mult
        hc_scale: shape (3,), dtype torch.float32
        hc_base: shape (hc_mult3,), dtype torch.float32
        rms_eps: RMS normalization epsilon
        hc_pre_eps: pre-mix epsilon
        hc_sinkhorn_eps: sinkhorn epsilon
        hc_post_mult_value: post-mix multiplier value
        sinkhorn_repeat: number of sinkhorn iterations
        n_splits: part of the signature, but the value passed in is
            overwritten by ``compute_num_split(...)`` before use

    Returns:
        post_mix: shape (..., hc_mult, 1), dtype torch.float32
        comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32
        layer_input: shape (..., hidden_size), dtype torch.bfloat16
    """
    # Validate dtypes
    assert residual.dtype == torch.bfloat16
    assert fn.dtype == torch.float32
    assert hc_scale.dtype == torch.float32
    assert hc_base.dtype == torch.float32

    # Derive the mixing sizes from the trailing two dims of the residual.
    hc_mult, hidden_size = residual.shape[-2], residual.shape[-1]
    hc_mult2 = hc_mult * hc_mult
    hc_mult3 = hc_mult * 2 + hc_mult2
    hc_hidden_size = hc_mult * hidden_size

    # Validate shapes
    assert fn.shape[0] == hc_mult3
    assert fn.shape[1] == hc_hidden_size
    assert hc_scale.shape == (3,)
    assert hc_base.shape == (hc_mult3,)

    # Collapse all leading dims into a single token axis for the kernels.
    outer_shape = residual.shape[:-2]
    residual_flat = residual.view(-1, hc_mult, hidden_size)
    num_tokens = residual_flat.shape[0]

    # these number are from deepgemm kernel impl
    block_k = 64
    block_m = 64
    n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m))

    f32 = dict(dtype=torch.float32, device=residual.device)
    bf16 = dict(dtype=torch.bfloat16, device=residual.device)

    # Kernel outputs
    post_mix = torch.empty(num_tokens, hc_mult, **f32)
    comb_mix = torch.empty(num_tokens, hc_mult2, **f32)
    layer_input = torch.empty(num_tokens, hidden_size, **bf16)
    # Split-K partial results produced by the GEMM and reduced by the kernel
    gemm_out_mul = torch.empty(n_splits, num_tokens, hc_mult3, **f32)
    gemm_out_sqrsum = torch.empty(n_splits, num_tokens, **f32)

    from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm

    tf32_hc_prenorm_gemm(
        residual_flat.view(num_tokens, hc_mult * hidden_size),
        fn,
        gemm_out_mul,
        gemm_out_sqrsum,
        n_splits,
    )
    mhc_pre_big_fuse_tilelang(
        gemm_out_mul,
        gemm_out_sqrsum,
        hc_scale,
        hc_base,
        residual_flat,
        post_mix,
        comb_mix,
        layer_input,
        hidden_size,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
        n_splits,
        hc_mult,
    )

    # Restore the caller's leading dims on every output.
    return (
        post_mix.view(*outer_shape, hc_mult, 1),
        comb_mix.view(*outer_shape, hc_mult, hc_mult),
        layer_input.view(*outer_shape, hidden_size),
    )
def _mhc_pre_fake(
residual: torch.Tensor,
fn: torch.Tensor,
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
rms_eps: float,
hc_pre_eps: float,
hc_sinkhorn_eps: float,
hc_post_mult_value: float,
sinkhorn_repeat: int,
n_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hc_mult = residual.shape[-2]
hidden_size = residual.shape[-1]
outer_shape = residual.shape[:-2]
# Create empty tensors with correct shapes for meta device / shape inference
post_mix = torch.empty(
*outer_shape,
hc_mult,
1,
dtype=torch.float32,
device=residual.device,
)
comb_mix = torch.empty(
*outer_shape,
hc_mult,
hc_mult,
dtype=torch.float32,
device=residual.device,
)
layer_input = torch.empty(
*outer_shape,
hidden_size,
dtype=torch.bfloat16,
device=residual.device,
)
return post_mix, comb_mix, layer_input
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
    },
)
def mhc_post_tilelang(
    a,
    b,
    c,
    d,
    x,
    hc: int,
    hidden: int,
    n_thr: int = 128,
    h_blk: int = 1024,
) -> tilelang.JITKernel:
    # Per token n and output stream o, this kernel writes
    #   x[n, o, :] = c[n, o] * d[n, :] + sum_i a[n, i, o] * b[n, i, :]
    # with shapes and dtypes as declared by the T.Tensor annotations below.
    # rename for shorter code
    n = T.dynamic("num_tokens")
    h = hidden
    # gcd makes the block size divide `hidden` evenly (at most the requested h_blk).
    h_blk = math.gcd(hidden, h_blk)
    a: T.Tensor((n, hc, hc), T.float32)  # type: ignore[no-redef, valid-type]
    b: T.Tensor((n, hc, h), T.bfloat16)  # type: ignore[no-redef, valid-type]
    c: T.Tensor((n, hc), T.float32)  # type: ignore[no-redef, valid-type]
    d: T.Tensor((n, h), T.bfloat16)  # type: ignore[no-redef, valid-type]
    x: T.Tensor((n, hc, h), T.bfloat16)  # type: ignore[no-redef, valid-type]
    # One kernel instance per token i_n.
    with T.Kernel(n, threads=n_thr) as i_n:
        # bfloat16 staging buffers and float32 working fragments
        x_shared = T.alloc_shared((hc, h_blk), T.bfloat16)
        b_shared = T.alloc_shared((hc, h_blk), T.bfloat16)
        d_shared = T.alloc_shared(h_blk, T.bfloat16)
        x_local = T.alloc_fragment((hc, h_blk), T.float32)
        b_local = T.alloc_fragment((hc, h_blk), T.float32)
        d_local = T.alloc_fragment(h_blk, T.float32)
        a_local = T.alloc_fragment((hc, hc), T.float32)
        c_local = T.alloc_fragment(hc, T.float32)
        # NOTE(review): presumably a programmatic-dependent-launch sync point —
        # confirm against the tilelang documentation.
        T.pdl_sync()
        # The per-token mixing coefficients are loaded once, outside the loop.
        T.copy(a[i_n, 0, 0], a_local)
        T.copy(c[i_n, 0], c_local)
        # Walk the hidden dimension h_blk columns at a time.
        for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2):
            T.copy(b[i_n, 0, i0_h * h_blk], b_shared)
            T.copy(d[i_n, i0_h * h_blk], d_shared)
            T.copy(b_shared, b_local)
            T.copy(d_shared, d_local)
            for i_hco, i1_h in T.Parallel(hc, h_blk):
                # Start from the c * d term, then accumulate the a^T @ b term.
                x_local[i_hco, i1_h] = c_local[i_hco] * d_local[i1_h]
                for i_hci in T.serial(hc):
                    x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h]
            # Write back through the bfloat16 shared buffer.
            T.copy(x_local, x_shared)
            T.copy(x_shared, x[i_n, 0, i0_h * h_blk])
        # NOTE(review): counterpart of T.pdl_sync() above — confirm semantics.
        T.pdl_trigger()
def mhc_post(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """Run the mHC post kernel and return the result as a new tensor.

    Args:
        x: Passed to the kernel as its ``d`` operand.
        residual: Passed as the kernel's ``b`` operand; its last two dims
            supply the ``hc`` and ``hidden`` kernel parameters.
        post_layer_mix: Passed as the kernel's ``c`` operand after its
            trailing dim is squeezed.
        comb_res_mix: Passed as the kernel's ``a`` operand.

    Returns:
        A freshly allocated tensor shaped like ``residual`` that the kernel
        fills in.
    """
    out = torch.empty_like(residual)
    post_coeffs = post_layer_mix.squeeze(-1)
    hc_mult, hidden_size = residual.shape[-2], residual.shape[-1]
    mhc_post_tilelang(
        comb_res_mix,
        residual,
        post_coeffs,
        x,
        out,
        hc_mult,
        hidden_size,
    )
    return out
def _mhc_post_fake(
x: torch.Tensor,
residual: torch.Tensor,
post_layer_mix: torch.Tensor,
comb_res_mix: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(residual)
# Register both entry points as custom ops, each paired with its fake
# implementation that only produces correctly shaped empty outputs.
# mutates_args is empty because both ops return newly allocated tensors
# rather than writing into their arguments.
direct_register_custom_op(
    op_name="mhc_pre",
    op_func=mhc_pre,
    mutates_args=[],
    fake_impl=_mhc_pre_fake,
)
direct_register_custom_op(
    op_name="mhc_post",
    op_func=mhc_post,
    mutates_args=[],
    fake_impl=_mhc_post_fake,
)
...@@ -32,6 +32,7 @@ QuantizationMethods = Literal[ ...@@ -32,6 +32,7 @@ QuantizationMethods = Literal[
"inc", "inc",
"mxfp4", "mxfp4",
"gpt_oss_mxfp4", "gpt_oss_mxfp4",
"deepseek_v4_fp8",
"cpu_awq", "cpu_awq",
"online", "online",
# Below are values of the OnlineQuantScheme enum, specified as strings to # Below are values of the OnlineQuantScheme enum, specified as strings to
...@@ -112,6 +113,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -112,6 +113,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
# lazy import to avoid triggering `torch.compile` too early # lazy import to avoid triggering `torch.compile` too early
from vllm.config.quantization import OnlineQuantScheme from vllm.config.quantization import OnlineQuantScheme
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from vllm.model_executor.models.deepseek_v4 import DeepseekV4FP8Config
from .awq import AWQConfig from .awq import AWQConfig
from .awq_marlin import AWQMarlinConfig from .awq_marlin import AWQMarlinConfig
...@@ -163,6 +165,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -163,6 +165,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"inc": INCConfig, "inc": INCConfig,
"mxfp4": Mxfp4Config, "mxfp4": Mxfp4Config,
"gpt_oss_mxfp4": GptOssMxfp4Config, "gpt_oss_mxfp4": GptOssMxfp4Config,
"deepseek_v4_fp8": DeepseekV4FP8Config,
"cpu_awq": CPUAWQConfig, "cpu_awq": CPUAWQConfig,
"humming": HummingConfig, "humming": HummingConfig,
"online": OnlineQuantizationConfig, "online": OnlineQuantizationConfig,
......
...@@ -265,6 +265,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -265,6 +265,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
......
...@@ -305,6 +305,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -305,6 +305,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet." assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet."
assert layer.activation in ( assert layer.activation in (
......
...@@ -367,6 +367,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -367,6 +367,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic( return self.moe_kernel.apply_monolithic(
......
...@@ -168,6 +168,7 @@ class CompressedTensorsW8A8Mxfp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -168,6 +168,7 @@ class CompressedTensorsW8A8Mxfp8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic( return self.moe_kernel.apply_monolithic(
......
...@@ -517,6 +517,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -517,6 +517,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.kernel_backend == "Flashinfer" assert self.kernel_backend == "Flashinfer"
return flashinfer_trtllm_mxint4_moe( return flashinfer_trtllm_mxint4_moe(
......
...@@ -269,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -269,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: Fp8Config): def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config self.quant_config = quant_config
self.is_scale_e8m0 = getattr(quant_config, "is_scale_e8m0", False)
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype self.input_dtype = get_current_vllm_config().model_config.dtype
...@@ -362,6 +363,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -362,6 +363,7 @@ class Fp8LinearMethod(LinearMethodBase):
input_size_per_partition, input_size_per_partition,
self.weight_block_size, self.weight_block_size,
weight_loader, weight_loader,
scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None),
) )
# The weight_scale_inv name is intentional for deepseekv3 # The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale) layer.register_parameter("weight_scale_inv", scale)
...@@ -866,6 +868,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -866,6 +868,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
......
...@@ -950,6 +950,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -950,6 +950,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
...@@ -1442,6 +1443,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1442,6 +1443,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
...@@ -1920,6 +1922,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): ...@@ -1920,6 +1922,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from flashinfer.fused_moe.core import ( from flashinfer.fused_moe.core import (
ActivationType, ActivationType,
......
...@@ -20,10 +20,12 @@ from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( ...@@ -20,10 +20,12 @@ from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
TRITON_BACKENDS, TRITON_BACKENDS,
Mxfp4MoeBackend, Mxfp4MoeBackend,
convert_gpt_oss_weight_to_mxfp4_moe_kernel_format, convert_gpt_oss_weight_to_mxfp4_moe_kernel_format,
convert_weight_to_mxfp4_moe_kernel_format,
make_mxfp4_moe_kernel, make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config, make_mxfp4_moe_quant_config,
mxfp4_round_up_hidden_size_and_intermediate_size, mxfp4_round_up_hidden_size_and_intermediate_size,
select_gpt_oss_mxfp4_moe_backend, select_gpt_oss_mxfp4_moe_backend,
select_mxfp4_moe_backend,
) )
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -217,6 +219,7 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase): ...@@ -217,6 +219,7 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
) )
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w13_weight_scale.quant_method = "block"
# down_proj (row parallel) # down_proj (row parallel)
w2_weight = torch.nn.Parameter( w2_weight = torch.nn.Parameter(
...@@ -242,6 +245,7 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase): ...@@ -242,6 +245,7 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
) )
layer.register_parameter("w2_weight_scale", w2_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w2_weight_scale.quant_method = "block"
if self.moe.has_bias: if self.moe.has_bias:
w13_bias = torch.nn.Parameter( w13_bias = torch.nn.Parameter(
...@@ -397,6 +401,9 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase): ...@@ -397,6 +401,9 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
w2_scale=w2_scale, w2_scale=w2_scale,
w1_bias=w1_bias, w1_bias=w1_bias,
w2_bias=w2_bias, w2_bias=w2_bias,
gemm1_alpha=1.702,
gemm1_beta=1.0,
swiglu_limit=7.0,
) )
def select_gemm_impl( def select_gemm_impl(
...@@ -437,6 +444,332 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase): ...@@ -437,6 +444,332 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
router_logits=router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
class Mxfp4MoEMethod(FusedMoEMethodBase):
"""MXFP4 MoE quantization method."""
def __init__(self, moe: FusedMoEConfig):
    """Set up the MXFP4 MoE method for the given MoE configuration.

    The backend and its experts class come from
    ``select_mxfp4_moe_backend``; the kernel and the triton precision
    configs start out unset and are assigned in ``_setup_kernel``.
    """
    super().__init__(moe)
    self.weight_dtype = "mxfp4"

    backend, experts_cls = select_mxfp4_moe_backend(moe)
    self.mxfp4_backend = backend
    self.experts_cls = experts_cls

    compilation_config = get_current_vllm_config().compilation_config
    self.max_capture_size = compilation_config.max_cudagraph_capture_size

    # Permutation-index cache keyed by tensor shape; passed to the weight
    # conversion helper in ``_setup_kernel``.
    self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
    # Assigned by ``_setup_kernel``.
    self.moe_kernel: mk.FusedMoEKernel | None = None

    # Used for triton kernel precision configs
    self.w13_precision_config = None
    self.w2_precision_config = None
@property
def skip_forward_padding(self) -> bool:
    """Whether padding in the forward pass can be skipped for this backend."""
    # The FlashInfer TRT-LLM MXFP4/MXFP8 backend supports padding with
    # mxfp8 quant, so the padding done in the forward before applying the
    # MoE method is unnecessary for it.
    padding_capable_backend = Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8
    return self.mxfp4_backend == padding_capable_backend
def maybe_roundup_sizes(
    self,
    hidden_size: int,
    intermediate_size_per_partition: int,
    act_dtype: torch.dtype,
    moe_parallel_config: FusedMoEParallelConfig,
) -> tuple[int, int]:
    """Round the hidden and intermediate sizes up for the selected backend.

    The base-class rounding is applied first; its result is then passed to
    ``mxfp4_round_up_hidden_size_and_intermediate_size`` together with the
    chosen MXFP4 backend.

    Returns:
        The ``(hidden_size, intermediate_size_per_partition)`` pair produced
        by the backend-specific rounding helper.
    """
    base_hidden, base_intermediate = super().maybe_roundup_sizes(
        hidden_size=hidden_size,
        intermediate_size_per_partition=intermediate_size_per_partition,
        act_dtype=act_dtype,
        moe_parallel_config=moe_parallel_config,
    )
    return mxfp4_round_up_hidden_size_and_intermediate_size(
        self.mxfp4_backend, base_hidden, base_intermediate
    )
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Allocate and register the zero-initialised MXFP4 expert parameters.

    Registers ``w13_weight``/``w13_weight_scale`` (fused gate_up_proj) and
    ``w2_weight``/``w2_weight_scale`` (down_proj) on ``layer`` as uint8
    parameters, plus bfloat16 ``w13_bias``/``w2_bias`` when
    ``self.moe.has_bias`` is set. Every parameter receives
    ``extra_weight_attrs``; the two scale parameters are additionally
    tagged with ``quant_method = "block"``.

    The sizes are also recorded on ``self`` (``num_experts``,
    ``intermediate_size``, ``hidden_size``) for the shape checks in
    ``_setup_kernel``.
    """
    packed_dtype = torch.uint8  # storage type of both weights and scales
    scale_block = 32  # elements covered by one scale entry

    def add_param(name, shape, dtype, block_scale=False):
        # Create -> register -> attach loader attrs, in that order; the
        # "block" tag is applied last so it is set after the shared attrs.
        param = torch.nn.Parameter(
            torch.zeros(*shape, dtype=dtype),
            requires_grad=False,
        )
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)
        if block_scale:
            param.quant_method = "block"

    self.num_experts = num_experts
    layer.params_dtype = params_dtype
    layer.num_experts = num_experts
    self.intermediate_size = intermediate_size_per_partition
    self.hidden_size = hidden_size

    gate_up_rows = 2 * intermediate_size_per_partition

    # Fused gate_up_proj (column parallel). The last dim is halved;
    # presumably two 4-bit values share one byte -- TODO confirm.
    add_param(
        "w13_weight",
        (num_experts, gate_up_rows, hidden_size // 2),
        packed_dtype,
    )
    add_param(
        "w13_weight_scale",
        (num_experts, gate_up_rows, hidden_size // scale_block),
        packed_dtype,
        block_scale=True,
    )

    # down_proj (row parallel)
    add_param(
        "w2_weight",
        (num_experts, hidden_size, intermediate_size_per_partition // 2),
        packed_dtype,
    )
    add_param(
        "w2_weight_scale",
        (
            num_experts,
            hidden_size,
            intermediate_size_per_partition // scale_block,
        ),
        packed_dtype,
        block_scale=True,
    )

    if self.moe.has_bias:
        add_param("w13_bias", (num_experts, gate_up_rows), torch.bfloat16)
        add_param("w2_bias", (num_experts, hidden_size), torch.bfloat16)
def _setup_kernel(
    self,
    layer: FusedMoE,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> None:
    """Validate the loaded tensors, convert them and build the MoE kernel.

    Steps:
      1. Assert that every tensor has the shape ``create_weights``
         allocated (using the sizes recorded on ``self``).
      2. Convert weights, scales and biases to the selected backend's
         kernel format.
      3. Install the converted tensors on ``layer`` (for TRITON backends
         the converted scales are kept on ``self`` as precision configs
         instead).
      4. Build the quant config and, when both a quant config and an
         experts class are available, the fused MoE kernel.

    Args:
        layer: The MoE layer that owns the parameters.
        w13: Fused gate_up_proj weights.
        w2: down_proj weights.
        w13_scale: Scales for ``w13``.
        w2_scale: Scales for ``w2``.
        w13_bias: Optional gate_up_proj bias.
        w2_bias: Optional down_proj bias.

    Raises:
        AssertionError: If any tensor does not have the expected shape.
    """
    num_experts = self.num_experts
    intermediate_size = self.intermediate_size
    hidden_size = self.hidden_size
    sf_block_size = 32

    # Shape assertions
    assert (
        w13.dim() == 3
        and w13.shape[0] == num_experts
        and w13.shape[1] == intermediate_size * 2
        and w13.shape[2] == hidden_size // 2
    )
    assert (
        w13_scale.dim() == 3
        and w13_scale.shape[0] == num_experts
        and w13_scale.shape[1] == intermediate_size * 2
        and w13_scale.shape[2] == hidden_size // sf_block_size
    )
    assert (
        w2.dim() == 3
        and w2.shape[0] == num_experts
        and w2.shape[1] == hidden_size
        and w2.shape[2] == intermediate_size // 2
    )
    assert (
        w2_scale.dim() == 3
        # Expert dimension is checked here too, matching the other tensors.
        and w2_scale.shape[0] == num_experts
        and w2_scale.shape[1] == hidden_size
        and w2_scale.shape[2] == intermediate_size // sf_block_size
    )
    if w13_bias is not None:
        assert (
            w13_bias.dim() == 2
            and w13_bias.shape[0] == num_experts
            and w13_bias.shape[1] == intermediate_size * 2
        )
    if w2_bias is not None:
        assert (
            w2_bias.dim() == 2
            and w2_bias.shape[0] == num_experts
            and w2_bias.shape[1] == hidden_size
        )

    # Convert weights to kernel format
    w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
        convert_weight_to_mxfp4_moe_kernel_format(
            mxfp4_backend=self.mxfp4_backend,
            layer=layer,
            w13_weight=w13,
            w2_weight=w2,
            w13_weight_scale=w13_scale,
            w2_weight_scale=w2_scale,
            w13_bias=w13_bias,
            w2_bias=w2_bias,
            _cache_permute_indices=self._cache_permute_indices,
        )
    )

    # For TRITON backends, weights are wrapped tensors from triton_kernels
    # that don't support .detach(). Manually assign parameters.
    if self.mxfp4_backend not in TRITON_BACKENDS:
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
    else:
        layer.w13_weight = w13
        layer.w2_weight = w2
        # The converted scales are consumed via get_fused_moe_quant_config.
        self.w13_precision_config = w13_scale
        self.w2_precision_config = w2_scale

    # Biases are only replaced when both are present.
    if w13_bias is not None and w2_bias is not None:
        replace_parameter(layer, "w13_bias", w13_bias)
        replace_parameter(layer, "w2_bias", w2_bias)

    # Build quant config
    self.moe_quant_config = self.get_fused_moe_quant_config(layer)

    # Build kernel (modular or monolithic)
    if self.moe_quant_config is not None and self.experts_cls is not None:
        self.moe_kernel = make_mxfp4_moe_kernel(
            moe_quant_config=self.moe_quant_config,
            moe_config=self.moe,
            mxfp4_backend=self.mxfp4_backend,
            experts_cls=self.experts_cls,
            routing_tables=layer._maybe_init_expert_routing_tables(),
            shared_experts=layer.shared_experts,
        )
def process_weights_after_loading(self, layer):
    """Hand the loaded expert tensors to ``_setup_kernel``.

    Nothing is set up when the selected backend is
    ``Mxfp4MoeBackend.NONE``. The bias tensors are optional and default to
    ``None`` when the layer has no such attribute.
    """
    loaded = (
        layer.w13_weight,
        layer.w2_weight,
        layer.w13_weight_scale,
        layer.w2_weight_scale,
        getattr(layer, "w13_bias", None),
        getattr(layer, "w2_bias", None),
    )

    if self.mxfp4_backend == Mxfp4MoeBackend.NONE:
        return

    # Positional order: w13, w2, w13_scale, w2_scale, w13_bias, w2_bias.
    self._setup_kernel(layer, *loaded)
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    """Build the fused-MoE quant config for ``layer``.

    Scales default to the layer's ``w13_weight_scale``/``w2_weight_scale``.
    For TRITON backends they are replaced by the precision configs stored
    on ``self``, which must already be set. Biases and ``swiglu_limit`` are
    optional layer attributes and default to ``None``.
    """
    scales = (layer.w13_weight_scale, layer.w2_weight_scale)
    biases = (getattr(layer, "w13_bias", None), getattr(layer, "w2_bias", None))
    limit = getattr(layer, "swiglu_limit", None)

    if self.mxfp4_backend in TRITON_BACKENDS:
        assert self.w13_precision_config is not None
        assert self.w2_precision_config is not None
        scales = (self.w13_precision_config, self.w2_precision_config)

    gate_up_scale, down_scale = scales
    gate_up_bias, down_bias = biases
    return make_mxfp4_moe_quant_config(
        mxfp4_backend=self.mxfp4_backend,
        w1_scale=gate_up_scale,
        w2_scale=down_scale,
        w1_bias=gate_up_bias,
        w2_bias=down_bias,
        swiglu_limit=limit,
    )
def select_gemm_impl(
    self,
    prepare_finalize: mk.FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> mk.FusedMoEExpertsModular:
    """Always raise: this method is not part of this class's kernel setup.

    Raises:
        ValueError: Unconditionally; the message names the concrete class.
    """
    owner = self.__class__.__name__
    message = (
        f"{owner} uses the new modular kernel "
        "initialization logic. This function should not be called."
    )
    raise ValueError(message)
def apply(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
    """Run the non-monolithic MoE kernel with precomputed top-k routing.

    Requires that this method is not monolithic and that the kernel has
    been built. The weights, activation, expert map and routing flags are
    read from ``layer``.
    """
    assert not self.is_monolithic
    kernel = self.moe_kernel
    assert kernel is not None

    return kernel.apply(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=layer.activation,
        global_num_experts=layer.global_num_experts,
        apply_router_weight_on_input=layer.apply_router_weight_on_input,
        expert_map=layer.expert_map,
        shared_experts_input=shared_experts_input,
    )
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
......
...@@ -130,6 +130,7 @@ class OnlineMoEMethodBase(FusedMoEMethodBase): ...@@ -130,6 +130,7 @@ class OnlineMoEMethodBase(FusedMoEMethodBase):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
......
...@@ -1457,6 +1457,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -1457,6 +1457,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
......
...@@ -149,6 +149,148 @@ def _per_token_group_quant_fp8( ...@@ -149,6 +149,148 @@ def _per_token_group_quant_fp8(
tl.store(y_s_ptr, y_s) tl.store(y_s_ptr, y_s)
# Fused SiLU-and-multiply followed by per-group quantization with
# power-of-two scales whose biased exponents are packed four per int32.
#
# Grid: axis 0 selects a pack of 4 consecutive column groups (pid_pack),
# axis 1 selects a tile of BLOCK_M rows (pid_m).
#
# For each group the kernel reads GROUP_SIZE "act" values from the first
# half of the row and the matching "mul" values N // 2 elements further on,
# computes act * sigmoid(act) * mul, derives one scale per row, writes the
# scaled values to output_q_ptr, and ORs the scale's biased exponent into
# byte `pack_idx` of that row's packed int32.
@triton.jit
def _silu_mul_quant_fp8_packed_kernel(
    input_ptr,
    output_q_ptr,
    output_scale_ptr,
    M,
    input_stride_m,
    output_q_stride_m,
    output_scale_stride_k,
    clamp_limit,
    N: tl.constexpr,
    NUM_GROUPS: tl.constexpr,
    fp8_min: tl.constexpr,
    fp8_max: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    HAS_CLAMP: tl.constexpr,
):
    # Offset from an "act" element to its "mul" partner within a row.
    N_2: tl.constexpr = N // 2

    pid_pack = tl.program_id(0)
    pid_m = tl.program_id(1)

    m_offset = pid_m * BLOCK_M
    # Row tile entirely past the end: nothing to do.
    if m_offset >= M:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, GROUP_SIZE)
    # Masks off rows of a partial last tile.
    row_mask = (m_offset + offs_m) < M

    base_row_offset = (m_offset + offs_m[:, None]) * input_stride_m
    base_out_offset = (m_offset + offs_m[:, None]) * output_q_stride_m

    # One int32 per row; bytes for groups beyond NUM_GROUPS stay zero.
    packed_scale = tl.zeros((BLOCK_M,), dtype=tl.int32)

    for pack_idx in tl.static_range(4):
        group_id = pid_pack * 4 + pack_idx
        if group_id < NUM_GROUPS:
            n_offset = group_id * GROUP_SIZE

            act_ptrs = input_ptr + base_row_offset + n_offset + offs_n[None, :]
            act_in = tl.load(act_ptrs, mask=row_mask[:, None], other=0.0)
            mul_ptrs = act_ptrs + N_2
            mul_in = tl.load(mul_ptrs, mask=row_mask[:, None], other=0.0)

            act_f32 = act_in.to(tl.float32)
            mul_f32 = mul_in.to(tl.float32)
            if HAS_CLAMP:
                # "act" is only capped from above; "mul" is clamped on
                # both sides.
                act_f32 = tl.minimum(act_f32, clamp_limit)
                mul_f32 = tl.clamp(mul_f32, -clamp_limit, clamp_limit)
            # act * sigmoid(act) * mul
            y = (act_f32 / (1.0 + tl.exp(-act_f32))) * mul_f32
            # Round through bf16 to match unfused precision path
            y = y.to(tl.bfloat16).to(tl.float32)

            # Per-row scale for this group, rounded up to a power of two.
            # The 1e-10 floor keeps log2 finite for all-zero rows.
            absmax = tl.max(tl.abs(y), axis=1)
            scale_raw = tl.maximum(absmax / fp8_max, 1e-10)
            exponent = tl.ceil(tl.log2(scale_raw))
            scale = tl.math.exp2(exponent)

            y_q = tl.clamp(y / scale[:, None], fp8_min, fp8_max)

            out_q_ptrs = output_q_ptr + base_out_offset + n_offset + offs_n[None, :]
            tl.store(
                out_q_ptrs,
                y_q.to(output_q_ptr.dtype.element_ty),
                mask=row_mask[:, None],
            )

            # Bias the exponent by 127, clamp it to one byte and place it
            # in byte `pack_idx` of the packed word.
            exponent_biased = tl.clamp(exponent + 127.0, 0.0, 255.0).to(tl.int32)
            packed_scale = packed_scale | (exponent_biased << (pack_idx * 8))

    # The row index is added without a stride, so the scale buffer is
    # addressed with unit stride along rows.
    scale_ptrs = output_scale_ptr + pid_pack * output_scale_stride_k + m_offset + offs_m
    tl.store(scale_ptrs, packed_scale, mask=row_mask)
def silu_mul_quant_fp8_packed_triton(
    input: torch.Tensor,
    group_size: int = 128,
    output_q: torch.Tensor | None = None,
    clamp_limit: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused SiLU-mul + FP8 group-quant kernel with packed scales.

    Args:
        input: Contiguous 2-D tensor of shape ``(M, N)``; ``N // 2`` must be
            a multiple of ``group_size``.
        group_size: Number of columns that share one scale.
        output_q: Optional preallocated output for the quantized values.
            When omitted, an ``(M, N // 2)`` ``float8_e4m3fn`` tensor is
            allocated on ``input``'s device.
        clamp_limit: Optional clamp bound forwarded to the kernel; clamping
            is disabled when ``None``.

    Returns:
        ``(output_q, output_scale_packed)``. The packed scales are an
        ``(M, ceil(num_groups / 4))`` int32 view whose backing storage is
        laid out transposed, with the row count padded to a multiple of 4.
    """
    assert input.dim() == 2
    assert input.is_contiguous()
    num_rows, num_cols = input.shape
    half_cols = num_cols // 2
    assert half_cols % group_size == 0

    quant_dtype = torch.float8_e4m3fn
    quant_range = torch.finfo(quant_dtype)
    fp8_min, fp8_max = quant_range.min, quant_range.max

    groups_per_row = half_cols // group_size
    # Four group scales share one int32.
    packed_cols = (groups_per_row + 3) // 4
    # Row count rounded up to a multiple of 4 for the scale storage.
    padded_rows = ((num_rows + 3) // 4) * 4

    if output_q is None:
        output_q = torch.empty(
            (num_rows, half_cols), dtype=quant_dtype, device=input.device
        )

    scale_storage = torch.zeros(
        (packed_cols, padded_rows),
        dtype=torch.int32,
        device=input.device,
    )
    # Transposed view: rows are contiguous in memory, padding rows dropped.
    output_scale_packed = scale_storage.T[:num_rows, :]

    rows_per_program = 8
    launch_grid = (
        packed_cols,
        (num_rows + rows_per_program - 1) // rows_per_program,
    )
    use_clamp = clamp_limit is not None

    _silu_mul_quant_fp8_packed_kernel[launch_grid](
        input,
        output_q,
        output_scale_packed,
        num_rows,
        input.stride(0),
        output_q.stride(0),
        output_scale_packed.stride(1),
        clamp_limit if use_clamp else 0.0,
        N=num_cols,
        NUM_GROUPS=groups_per_row,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        GROUP_SIZE=group_size,
        BLOCK_M=rows_per_program,
        HAS_CLAMP=use_clamp,
        num_warps=max(4, group_size // 32),
        num_stages=2,
    )

    return output_q, output_scale_packed
@triton.jit @triton.jit
def _silu_mul_per_token_group_quant_fp8_colmajor( def _silu_mul_per_token_group_quant_fp8_colmajor(
y_ptr, # [M, N] y_ptr, # [M, N]
...@@ -823,19 +965,65 @@ def requant_weight_ue8m0_inplace( ...@@ -823,19 +965,65 @@ def requant_weight_ue8m0_inplace(
s_old.copy_(s_requant) s_old.copy_(s_requant)
def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor:
    """Reinterpret E8M0 scale bytes as float32 powers of two.

    Each input byte is treated as an 8-bit biased exponent (bias 127) and
    shifted into bits 23-30 of an IEEE-754 float32, leaving the sign and
    mantissa bits zero. A byte of 0 therefore yields 0.0 and a byte of 255
    yields +inf.
    """
    biased_exponent = scale.view(torch.uint8).to(torch.int32)
    # Shift the exponent into the float32 exponent field, then reinterpret
    # the int32 bit pattern as float32 without converting values.
    return (biased_exponent << 23).view(torch.float32)
def deepgemm_post_process_fp8_weight_block( def deepgemm_post_process_fp8_weight_block(
wq: torch.Tensor, ws: torch.Tensor, quant_block_shape: tuple[int], use_e8m0: bool wq: torch.Tensor,
ws: torch.Tensor,
quant_block_shape: tuple[int, ...],
use_e8m0: bool,
is_bmm: bool = False,
bmm_batch_size: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert wq.dtype == torch.float8_e4m3fn, ( assert wq.dtype == torch.float8_e4m3fn, (
"Expected quantized tensor dtype " "Expected quantized tensor dtype "
f"to be torch.float8_e4m3fn, got {wq.dtype} instead." f"to be torch.float8_e4m3fn, got {wq.dtype} instead."
) )
assert ws.dtype == torch.float32, (
f"Expected tensor scales dtype to be torch.float32, got {ws.dtype} instead"
)
if use_e8m0: if ws.dtype == torch.float8_e8m0fnu:
requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape) # Scales already in E8M0 from checkpoint — upcast to fp32
# and skip requantization (weights already have power-of-two scales).
ws = _upcast_e8m0_to_fp32(ws)
else:
assert ws.dtype == torch.float32, (
f"Expected tensor scales dtype to be torch.float32 or "
f"torch.float8_e8m0fnu, got {ws.dtype} instead"
)
if use_e8m0:
requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape)
if is_bmm:
# Reshape 2D weight/scale to 3D for grouped BMM (einsum):
# wq: (g*r, d) -> (g, r, d)
# ws: (g*r/128, d/128) -> (g, r/128, d/128)
g = bmm_batch_size
assert wq.ndim == 2 and ws.ndim == 2
d = wq.size(1)
r = wq.size(0) // g
wq = wq.view(g, r, d)
ws = ws.view(g, r // quant_block_shape[0], d // quant_block_shape[1])
# Pre-transform scale with recipe=(1, 128, 128) to broadcast + pack
# into TMA-aligned UE8M0 (INT32) layout. At runtime fp8_einsum uses
# recipe=(1, 1, 128) which sees INT dtype and skips re-transform.
dg_ws = transform_sf_into_required_layout(
sf=ws,
mn=r,
k=d,
recipe=(1, quant_block_shape[0], quant_block_shape[1]),
num_groups=g,
is_sfa=False,
)
return wq, dg_ws
original_ndim = wq.ndim original_ndim = wq.ndim
if wq.ndim == 2: if wq.ndim == 2:
...@@ -984,11 +1172,13 @@ def create_fp8_scale_parameter( ...@@ -984,11 +1172,13 @@ def create_fp8_scale_parameter(
input_size_per_partition: int, input_size_per_partition: int,
block_size: list[int] | None, block_size: list[int] | None,
weight_loader: Callable | None, weight_loader: Callable | None,
scale_dtype: torch.dtype | None = None,
) -> torch.nn.Parameter: ) -> torch.nn.Parameter:
"""Create scale parameter based on quantization strategy.""" """Create scale parameter based on quantization strategy."""
dtype = scale_dtype if scale_dtype is not None else torch.float32
if parameter_type == ChannelQuantScaleParameter: if parameter_type == ChannelQuantScaleParameter:
scale = parameter_type( scale = parameter_type(
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), data=torch.empty((sum(output_partition_sizes), 1), dtype=dtype),
output_dim=0, output_dim=0,
weight_loader=weight_loader, weight_loader=weight_loader,
) )
...@@ -1000,7 +1190,7 @@ def create_fp8_scale_parameter( ...@@ -1000,7 +1190,7 @@ def create_fp8_scale_parameter(
data=torch.empty( data=torch.empty(
(output_size_per_partition + block_n - 1) // block_n, (output_size_per_partition + block_n - 1) // block_n,
(input_size_per_partition + block_k - 1) // block_k, (input_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32, dtype=dtype,
), ),
input_dim=1, input_dim=1,
output_dim=0, output_dim=0,
...@@ -1008,13 +1198,14 @@ def create_fp8_scale_parameter( ...@@ -1008,13 +1198,14 @@ def create_fp8_scale_parameter(
) )
elif parameter_type == PerTensorScaleParameter: elif parameter_type == PerTensorScaleParameter:
scale = parameter_type( scale = parameter_type(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32), data=torch.empty(len(output_partition_sizes), dtype=dtype),
weight_loader=weight_loader, weight_loader=weight_loader,
) )
else: else:
raise ValueError(f"Unknown parameter type: {parameter_type}") raise ValueError(f"Unknown parameter type: {parameter_type}")
scale[:] = torch.finfo(torch.float32).min if dtype == torch.float32:
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"}) set_weight_attrs(scale, {"scale_type": "weight_scale"})
return scale return scale
......
...@@ -7,7 +7,10 @@ from typing import Any ...@@ -7,7 +7,10 @@ from typing import Any
import torch import torch
from .base import RotaryEmbedding from .base import RotaryEmbedding
from .deepseek_scaling_rope import DeepseekScalingRotaryEmbedding from .deepseek_scaling_rope import (
DeepseekScalingRotaryEmbedding,
DeepseekV4ScalingRotaryEmbedding,
)
from .dual_chunk_rope import DualChunkRotaryEmbedding from .dual_chunk_rope import DualChunkRotaryEmbedding
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
...@@ -60,11 +63,13 @@ def get_rope( ...@@ -60,11 +63,13 @@ def get_rope(
rope_parameters = rope_parameters or {} rope_parameters = rope_parameters or {}
base = rope_parameters.get("rope_theta", 10000) base = rope_parameters.get("rope_theta", 10000)
scaling_type = rope_parameters.get("rope_type", "default") scaling_type = rope_parameters.get("rope_type", "default")
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0) if rotary_dim := rope_parameters.get("rope_dim", None):
pass
if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0: else:
raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0") partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
rotary_dim = int(head_size * partial_rotary_factor) if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0:
raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0")
rotary_dim = int(head_size * partial_rotary_factor)
key = ( key = (
head_size, head_size,
...@@ -289,7 +294,11 @@ def get_rope( ...@@ -289,7 +294,11 @@ def get_rope(
"mscale_all_dim", "mscale_all_dim",
) )
} }
rotary_emb = DeepseekScalingRotaryEmbedding( if rope_parameters.get("is_deepseek_v4", False):
cls = DeepseekV4ScalingRotaryEmbedding
else:
cls = DeepseekScalingRotaryEmbedding
rotary_emb = cls(
head_size, head_size,
rotary_dim, rotary_dim,
original_max_position, original_max_position,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment