Unverified commit e9697374, authored by kk, committed by GitHub

Optimized deepseek-v3/r1 model performance on mxfp4 run (#10008)


Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com>
parent 93088b69
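
Editorial note (not part of the commit): the new kernels introduced here are gated behind the SGLANG_USE_AITER environment variable and a gfx95-class AMD GPU, and they only engage for Quark mxfp4-quantized DeepSeek-V3/R1 checkpoints. A minimal sketch of that gating, using only helpers that appear in this diff:

    import os

    os.environ["SGLANG_USE_AITER"] = "1"  # opt in to the aiter kernels on ROCm

    from sglang.srt.utils import get_bool_env_var, is_gfx95_supported, is_hip

    # Mirrors the module-level _use_aiter_gfx95 gating added in this commit.
    use_aiter_gfx95 = (
        get_bool_env_var("SGLANG_USE_AITER") and is_hip() and is_gfx95_supported()
    )
    print("mxfp4 fast path available:", use_aiter_gfx95)
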
@@ -43,8 +43,11 @@ from sglang.srt.layers.moe import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import (
    get_bool_env_var,
    is_cuda,
    is_flashinfer_available,
    is_gfx95_supported,
    is_hip,
    is_sm90_supported,
    is_sm100_supported,
)
@@ -52,6 +55,11 @@ from sglang.srt.utils import (
_is_flashinfer_available = is_flashinfer_available()
_is_sm90_supported = is_cuda() and is_sm90_supported()
_is_sm100_supported = is_cuda() and is_sm100_supported()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
_is_gfx95_supported = is_gfx95_supported()

if _use_aiter and _is_gfx95_supported:
    from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant

FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
@@ -207,6 +215,7 @@ class LayerCommunicator:
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        quant_format: str = "",
    ):
        if hidden_states.shape[0] == 0:
            residual = hidden_states
@@ -224,11 +233,34 @@ class LayerCommunicator:
        else:
            if residual is None:
                residual = hidden_states
                if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format):
                    hidden_states = fused_rms_mxfp4_quant(
                        hidden_states,
                        self.input_layernorm.weight,
                        self.input_layernorm.variance_epsilon,
                        None,
                        None,
                        None,
                        None,
                    )
                else:
                    hidden_states = self.input_layernorm(hidden_states)
            else:
                if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format):
                    hidden_states, residual = fused_rms_mxfp4_quant(
                        hidden_states,
                        self.input_layernorm.weight,
                        self.input_layernorm.variance_epsilon,
                        None,
                        None,
                        None,
                        residual,
                    )
                else:
                    hidden_states, residual = self.input_layernorm(
                        hidden_states, residual
                    )

        hidden_states = self._communicate_simple_fn(
            hidden_states=hidden_states,
......
@@ -8,6 +8,7 @@ import torch.nn.functional as F
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
from aiter.ops.shuffle import shuffle_weight
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from aiter.utility import dtypes
from aiter.utility.fp4_utils import e8m0_shuffle
@@ -38,15 +39,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        return

        # for aiter implement
        # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
        # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)

        # layer.weight = torch.nn.Parameter(wshuffle,
        #                                   requires_grad=False)
        # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
        #                                         requires_grad=False)

    def create_weights(
        self,
        layer: torch.nn.Module,
@@ -93,26 +85,53 @@ class QuarkW4A4MXFP4(QuarkScheme):
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # This path does not have support for bias currently
        assert bias is None, "bias is not supported"

        three_d = False
        x_s = None
        y = None
        if isinstance(x, tuple):
            assert len(x) in [
                2,
                3,
            ], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
            if len(x) == 2:
                x, x_s = x
            elif len(x) == 3:
                x, x_s, y = x

        use_fused_quant_gemm = (
            x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
        )

        if x.dim() == 3:
            three_d = True
            output_shape = [*x.shape[:-1], layer.weight.shape[0]]
            x = x.view(-1, x.shape[-1])

        # If use_fused_quant_gemm is set, x stays in bf16/fp16 and is quantized inside the kernel;
        # if x_s is not None, x is already quantized to mxfp4 (uint8).
        if use_fused_quant_gemm or x_s is not None:
            x_q = x
        else:
            x_q, x_s = dynamic_mxfp4_quant(x)

        if y is None:
            y = torch.empty(
                x_q.shape[0],
                layer.weight.shape[0],
                device=x_q.device,
                dtype=self.out_dtype,
            )

        if use_fused_quant_gemm:
            gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
            y = y.to(x.dtype)
        else:
            gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)

        if three_d:
            return y.view(*output_shape)

        return y
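
Reading aid (not part of the diff): the tuple assertion above means this apply path accepts three input conventions. The sketch below assumes the Quark scheme's apply entry point is apply_weights(layer, x, bias=None); variable names are illustrative.

    # 1) plain bf16/fp16 activations: quantized to mxfp4 inside this function
    out = scheme.apply_weights(layer, x_bf16)
    # 2) (x, x_s): activations already quantized, e.g. by fused_rms_mxfp4_quant; x_s holds the e8m0 scales
    out = scheme.apply_weights(layer, (x_q_uint8, x_scales))
    # 3) (x, None, y): unquantized activations plus a preallocated float32 output buffer
    #    (drawn zero-initialized from the gemm_output_zero_allocator elsewhere in this diff);
    #    this selects gemm_afp4wfp4_pre_quant, which quantizes and accumulates into y
    out = scheme.apply_weights(layer, (x_bf16, None, y_fp32))
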
@@ -5,6 +5,10 @@ from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any, Optional
import torch
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from torch import nn
def deep_compare(dict1: Any, dict2: Any) -> bool:
    if type(dict1) is not type(dict2):
@@ -105,3 +109,96 @@ def _is_equal_or_regex_match(
    elif target == value:
        return True
    return False
# utility for tensor dims > 2 cases
def b_dynamic_mxfp4_quant(x):
    h, b, d = x.shape
    x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
    return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
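
Illustration (not part of the diff) of the shape bookkeeping above, for hypothetical dimensions: mxfp4 packs two 4-bit values per uint8 and keeps one e8m0 scale per 32 elements, hence the d // 2 and d // 32 trailing dims.

    x = torch.randn(128, 512, 128, dtype=torch.bfloat16, device="cuda")
    x_q, x_s = b_dynamic_mxfp4_quant(x)
    assert x_q.shape == (128, 512, 64)  # packed data: 128 // 2
    assert x_s.shape == (128, 512, 4)   # block scales: 128 // 32
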
def mxfp4_to_f32(x, is_threed):
    # 2 because we pack fp4 in uint8.
    x = x.repeat_interleave(2, dim=-1)
    if is_threed:
        x[..., ::2] = x[..., ::2] & 0xF
        x[..., 1::2] = x[..., 1::2] >> 4
    else:
        x[:, ::2] = x[:, ::2] & 0xF
        x[:, 1::2] = x[:, 1::2] >> 4

    mxfp4_list = [
        0.0,
        0.5,
        1.0,
        1.5,
        2.0,
        3.0,
        4.0,
        6.0,
        -0.0,
        -0.5,
        -1.0,
        -1.5,
        -2.0,
        -3.0,
        -4.0,
        -6.0,
    ]
    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
    return mxfp4_in_f32[x.long()]
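
Worked example (not part of the diff) of the unpacking convention above: the low nibble of each byte holds the even element, the high nibble the odd one, and each nibble indexes the 16-entry mxfp4 value table.

    packed = torch.tensor([[0x71]], dtype=torch.uint8, device="cuda")
    # low nibble 0x1 -> 0.5, high nibble 0x7 -> 6.0
    print(mxfp4_to_f32(packed, is_threed=False))  # -> [[0.5, 6.0]]
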
def e8m0_to_f32(x):
    # e8m0 is a custom 8-bit floating-point format with 8 exponent bits and no mantissa,
    # so a stored byte represents 2^(exponent - 127), analogous to how IEEE-754 biases exponents.
    x_f32 = 2 ** ((x.to(torch.float32)) - 127)
    # The all-ones exponent (255) is the special encoding usually reserved for NaN/Inf;
    # since this format has no mantissa, treat it as NaN.
    x_f32[x == 255] = float("nan")
    return x_f32
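
Worked values (not part of the diff) for the e8m0 conversion above:

    scales = torch.tensor([127, 130, 121], dtype=torch.uint8, device="cuda")
    print(e8m0_to_f32(scales))  # 2**0, 2**3, 2**-6  ->  [1.0, 8.0, 0.015625]
    # A stored byte of 255 (all-ones exponent) is the special case mapped to NaN above.
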
def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
    if "mxfp4" in quant_format:
        # If the dtype is bf16, dynamically quantize the bf16 tensor to uint8:
        # quantize w_kc first to get w_kc (uint8) and w_s_kc (e8m0 scales),
        # then repeat the same procedure for w_vc to get w_vc (uint8) and w_s_vc.
        if w.dtype == torch.bfloat16:
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
            w_kc = w_kc.transpose(-2, -1)
            w_s_kc = w_s_kc.transpose(-2, -1)
            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
        elif w.dtype == torch.uint8:  # static quant for mxfp4
            # If the dtype is uint8, w has already been quantized to mxfp4, but it still
            # has to be split into w_kc and w_vc. The quantized tensor is only half the
            # original size and its scale tensor is 1/32 of it, so transposing the packed
            # data directly would be incorrect. Upcast to fp32, split into w_kc and w_vc
            # so the transposes behave correctly, then quantize back to mxfp4.
            w = mxfp4_to_f32(w, True).to(torch.bfloat16)
            w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
            w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
            w = w * w_scales
            w_kc, w_vc = w.unflatten(
                0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
            w_kc = w_kc.transpose(-2, -1)
            w_s_kc = w_s_kc.transpose(-2, -1)
            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
            w_s_vc = w_s_vc.contiguous().transpose(1, 2)

    return w_kc, w_s_kc, w_vc, w_s_vc
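
Orientation note (editorial, not part of the diff): the deepseek_v2.py hunk later in this commit consumes these return values roughly as sketched below, binding the packed uint8 weights and their e8m0 block scales onto the attention module so the absorbed-MLA decode path can feed them to batched_gemm_afp4wfp4_pre_quant. The import block that follows belongs to the new rocm_mxfp4_utils wrapper module referenced elsewhere in the diff.

    if _use_aiter_gfx95 and self.quant_config.get_name() == "quark":
        w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
            quark_post_load_weights(self_attn, w, "mxfp4")
        )
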
from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
    batched_gemm_afp4wfp4_pre_quant,
)
from aiter.ops.triton.fused_mxfp4_quant import (
    fused_flatten_mxfp4_quant,
    fused_rms_mxfp4_quant,
)

__all__ = [
    "fused_rms_mxfp4_quant",
    "fused_flatten_mxfp4_quant",
    "batched_gemm_afp4wfp4_pre_quant",
]
import torch
from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic

from sglang.srt.utils import BumpAllocator

__all__ = ["fused_qk_rope_cat"]


def aiter_dsv3_router_gemm(
    hidden_states: torch.Tensor,
    weight: torch.Tensor,
    gemm_output_zero_allocator: BumpAllocator = None,
):
    M = hidden_states.shape[0]
    N = weight.shape[0]
    y = None

    if M <= 256:
        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
        # for now it is also coupled with zero allocator.
        if gemm_output_zero_allocator is not None:
            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
        else:
            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)

    if y is not None:
        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
    else:
        logits = gemm_a16w16(hidden_states, weight)

    return logits


def get_dsv3_gemm_output_zero_allocator_size(
    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
):
    if embedding_dim != 7168 or n_routed_experts != 256:
        return 0

    per_layer_size = 256 * (allocate_size + n_routed_experts)

    return num_moe_layers * per_layer_size
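
Rough sizing illustration (editorial; the TP degree and MoE layer count below are assumptions, not taken from the diff): for up to 256 tokens per layer, the buffer must hold one float32 router-logits row (n_routed_experts wide) plus one shared-expert gate_up output row (allocate_size wide) per token, which is exactly per_layer_size above.

    n_routed_experts = 256
    allocate_size = 4096   # assumed shared_experts.gate_up_proj.output_size_per_partition
    num_moe_layers = 58    # assumed number of MoE layers
    per_layer_size = 256 * (allocate_size + n_routed_experts)  # 1,114,112 fp32 elements
    total = num_moe_layers * per_layer_size                    # ~64.6M elements, about 246 MiB
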
@@ -112,6 +112,7 @@ from sglang.srt.utils import (
    is_cpu,
    is_cuda,
    is_flashinfer_available,
    is_gfx95_supported,
    is_hip,
    is_non_idle_and_non_empty,
    is_npu,
@@ -129,6 +130,22 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_device_sm = get_device_sm()
_is_gfx95_supported = is_gfx95_supported()

_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported

if _use_aiter_gfx95:
    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
        batched_gemm_afp4wfp4_pre_quant,
        fused_flatten_mxfp4_quant,
        fused_rms_mxfp4_quant,
    )
    from sglang.srt.layers.rocm_linear_utils import (
        aiter_dsv3_router_gemm,
        fused_qk_rope_cat,
        get_dsv3_gemm_output_zero_allocator_size,
    )
if _is_cuda:
    from sgl_kernel import (
@@ -224,10 +241,17 @@ class DeepseekV2MLP(nn.Module):
        forward_batch=None,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ):
        if (self.tp_size == 1) and x.shape[0] == 0:
            return x

        if gemm_output_zero_allocator is not None and x.shape[0] <= 256:
            y = gemm_output_zero_allocator.allocate(
                x.shape[0] * self.gate_up_proj.output_size_per_partition
            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
            x = (x, None, y)

        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(
@@ -257,7 +281,7 @@ class MoEGate(nn.Module):
        if _is_cpu and _is_cpu_amx_available:
            self.quant_method = PackWeightMethod(weight_names=["weight"])

    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
        if use_intel_amx_backend(self):
            return torch.ops.sgl_kernel.weight_packed_linear(
                hidden_states,
@@ -276,6 +300,10 @@ class MoEGate(nn.Module):
        ):
            # router gemm output float32
            logits = dsv3_router_gemm(hidden_states, self.weight)
        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
            logits = aiter_dsv3_router_gemm(
                hidden_states, self.weight, gemm_output_zero_allocator
            )
        else:
            logits = F.linear(hidden_states, self.weight, None)
@@ -439,6 +467,7 @@ class DeepseekV2MoE(nn.Module):
        forward_batch: Optional[ForwardBatch] = None,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        if not self._enable_deepep_moe:
            DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -452,12 +481,14 @@ class DeepseekV2MoE(nn.Module):
                    hidden_states,
                    should_allreduce_fusion,
                    use_reduce_scatter,
                    gemm_output_zero_allocator,
                )
            else:
                return self.forward_normal(
                    hidden_states,
                    should_allreduce_fusion,
                    use_reduce_scatter,
                    gemm_output_zero_allocator,
                )
        else:
            return self.forward_deepep(hidden_states, forward_batch)
@@ -467,15 +498,18 @@ class DeepseekV2MoE(nn.Module):
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        current_stream = torch.cuda.current_stream()
        self.alt_stream.wait_stream(current_stream)
        shared_output = self._forward_shared_experts(
            hidden_states, gemm_output_zero_allocator
        )

        with torch.cuda.stream(self.alt_stream):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
            topk_output = self.topk(hidden_states, router_logits)
            final_hidden_states = self.experts(hidden_states, topk_output)
            if not _is_cuda:
@@ -502,6 +536,7 @@ class DeepseekV2MoE(nn.Module):
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        if hasattr(self, "shared_experts") and use_intel_amx_backend(
            self.shared_experts.gate_up_proj
@@ -509,9 +544,11 @@ class DeepseekV2MoE(nn.Module):
            return self.forward_cpu(hidden_states, should_allreduce_fusion)

        if hidden_states.shape[0] > 0:
            shared_output = self._forward_shared_experts(
                hidden_states, gemm_output_zero_allocator
            )
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
            topk_output = self.topk(hidden_states, router_logits)
        else:
            shared_output = None
@@ -631,9 +668,13 @@ class DeepseekV2MoE(nn.Module):
        return final_hidden_states

    def _forward_shared_experts(
        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
    ):
        if self.num_fused_shared_experts == 0:
            return self.shared_experts(
                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
            )
        else:
            return None
@@ -1097,11 +1138,19 @@ class DeepseekV2AttentionMLA(nn.Module):
        if self.attn_mha.kv_b_proj is None:
            self.attn_mha.kv_b_proj = self.kv_b_proj

        # When hidden_states is a tuple, it holds the mxfp4-quantized hidden states and their scale tensor.
        if isinstance(hidden_states, tuple):
            if hidden_states[0].shape[0] == 0:
                assert (
                    not self.o_proj.reduce_results
                ), "short-circuiting allreduce will lead to hangs"
                return hidden_states[0]
        else:
            if hidden_states.shape[0] == 0:
                assert (
                    not self.o_proj.reduce_results
                ), "short-circuiting allreduce will lead to hangs"
                return hidden_states, None, forward_batch, None

        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
@@ -1225,7 +1274,11 @@ class DeepseekV2AttentionMLA(nn.Module):
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

        if self.q_lora_rank is not None:
            if (
                (not isinstance(hidden_states, tuple))
                and hidden_states.shape[0] <= 16
                and self.use_min_latency_fused_a_gemm
            ):
                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                )
@@ -1245,8 +1298,18 @@ class DeepseekV2AttentionMLA(nn.Module):
                k_nope = self.kv_a_layernorm(k_nope)
                current_stream.wait_stream(self.alt_stream)
            else:
                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
                    q, k_nope = fused_rms_mxfp4_quant(
                        q,
                        self.q_a_layernorm.weight,
                        self.q_a_layernorm.variance_epsilon,
                        k_nope,
                        self.kv_a_layernorm.weight,
                        self.kv_a_layernorm.variance_epsilon,
                    )
                else:
                    q = self.q_a_layernorm(q)
                    k_nope = self.kv_a_layernorm(k_nope)

            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
@@ -1278,10 +1341,27 @@ class DeepseekV2AttentionMLA(nn.Module):
            q_nope_out = q_nope_out[:, :expected_m, :]
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
                x = q_nope.transpose(0, 1)
                q_nope_out = torch.empty(
                    x.shape[0],
                    x.shape[1],
                    self.w_kc.shape[2],
                    device=x.device,
                    dtype=torch.bfloat16,
                )
                batched_gemm_afp4wfp4_pre_quant(
                    x,
                    self.w_kc.transpose(-2, -1),
                    self.w_scale_k.transpose(-2, -1),
                    torch.bfloat16,
                    q_nope_out,
                )
            else:
                q_nope_out = torch.bmm(
                    q_nope.to(torch.bfloat16).transpose(0, 1),
                    self.w_kc.to(torch.bfloat16) * self.w_scale,
                )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
@@ -1295,13 +1375,15 @@ class DeepseekV2AttentionMLA(nn.Module):
        q_nope_out = q_nope_out.transpose(0, 1)

        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
            not _use_aiter or not _is_gfx95_supported
        ):
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions

    def forward_absorb_core(
        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
    ):
        if (
            self.current_attention_backend == "fa3"
@@ -1326,8 +1408,23 @@ class DeepseekV2AttentionMLA(nn.Module):
                **extra_args,
            )
        else:
            if _use_aiter_gfx95:
                cos = self.rotary_emb.cos_cache
                sin = self.rotary_emb.sin_cache
                q, k = fused_qk_rope_cat(
                    q_nope_out,
                    q_pe,
                    k_nope,
                    k_pe,
                    positions,
                    cos,
                    sin,
                    self.rotary_emb.is_neox_style,
                )
            else:
                q = torch.cat([q_nope_out, q_pe], dim=-1)
                k = torch.cat([k_nope, k_pe], dim=-1)

            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
@@ -1352,11 +1449,34 @@ class DeepseekV2AttentionMLA(nn.Module):
            )
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
                x = attn_output.transpose(0, 1)
                attn_bmm_output = torch.empty(
                    x.shape[0],
                    x.shape[1],
                    self.w_vc.shape[2],
                    device=x.device,
                    dtype=torch.bfloat16,
                )
                batched_gemm_afp4wfp4_pre_quant(
                    x,
                    self.w_vc.transpose(-2, -1),
                    self.w_scale_v.transpose(-2, -1),
                    torch.bfloat16,
                    attn_bmm_output,
                )
            else:
                attn_bmm_output = torch.bmm(
                    attn_output.to(torch.bfloat16).transpose(0, 1),
                    self.w_vc.to(torch.bfloat16) * self.w_scale,
                )

            if self.o_proj.weight.dtype == torch.uint8:
                attn_bmm_output = attn_bmm_output.transpose(0, 1)
                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
            else:
                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
@@ -1866,10 +1986,21 @@ class DeepseekV2DecoderLayer(nn.Module):
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        quant_format = (
            "mxfp4"
            if _is_gfx95_supported
            and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
            else ""
        )

        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states,
            residual,
            forward_batch,
            quant_format,
        )

        hidden_states = self.self_attn(
@@ -1893,8 +2024,16 @@ class DeepseekV2DecoderLayer(nn.Module):
        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
            forward_batch
        )

        if isinstance(self.mlp, DeepseekV2MLP):
            gemm_output_zero_allocator = None

        hidden_states = self.mlp(
            hidden_states,
            forward_batch,
            should_allreduce_fusion,
            use_reduce_scatter,
            gemm_output_zero_allocator,
        )

        if should_allreduce_fusion:
@@ -2038,6 +2177,37 @@ class DeepseekV2Model(nn.Module):
        else:
            self.norm = PPMissingLayer(return_tuple=True)

        self.gemm_output_zero_allocator_size = 0
        if (
            _use_aiter_gfx95
            and config.n_routed_experts == 256
            and self.embed_tokens.embedding_dim == 7168
        ):
            num_moe_layers = sum(
                [
                    1
                    for i in range(len(self.layers))
                    if isinstance(self.layers[i].mlp, DeepseekV2MoE)
                ]
            )

            allocate_size = 0
            for i in range(len(self.layers)):
                if isinstance(self.layers[i].mlp, DeepseekV2MoE):
                    allocate_size = self.layers[
                        i
                    ].mlp.shared_experts.gate_up_proj.output_size_per_partition
                    break

            self.gemm_output_zero_allocator_size = (
                get_dsv3_gemm_output_zero_allocator_size(
                    config.n_routed_experts,
                    num_moe_layers,
                    allocate_size,
                    self.embed_tokens.embedding_dim,
                )
            )
    def get_input_embeddings(self) -> torch.Tensor:
        return self.embed_tokens
@@ -2057,6 +2227,21 @@ class DeepseekV2Model(nn.Module):
            device=device,
        )

        has_gemm_output_zero_allocator = hasattr(
            self, "gemm_output_zero_allocator_size"
        )

        gemm_output_zero_allocator = (
            BumpAllocator(
                buffer_size=self.gemm_output_zero_allocator_size,
                dtype=torch.float32,
                device=device,
            )
            if has_gemm_output_zero_allocator
            and self.gemm_output_zero_allocator_size > 0
            else None
        )

        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
@@ -2083,7 +2268,12 @@ class DeepseekV2Model(nn.Module):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions,
                    hidden_states,
                    forward_batch,
                    residual,
                    zero_allocator,
                    gemm_output_zero_allocator,
                )

        if normal_end_layer != self.end_layer:
@@ -2356,6 +2546,12 @@ class DeepseekV2ForCausalLM(nn.Module):
                w_kc, w_vc = w.unflatten(
                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)

                if _use_aiter_gfx95 and self.quant_config.get_name() == "quark":
                    w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
                        quark_post_load_weights(self_attn, w, "mxfp4")
                    )

                if not use_deep_gemm_bmm:
                    self_attn.w_kc = bind_or_assign(
                        self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
......
@@ -153,7 +153,13 @@ class Glm4MoeMLP(nn.Module):
        )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        forward_batch=None,
        should_allreduce_fusion=False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ):
        if (self.tp_size == 1) and x.shape[0] == 0:
            return x
@@ -501,6 +507,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        current_stream = torch.cuda.current_stream()
@@ -543,6 +550,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        if hasattr(self, "shared_experts") and use_intel_amx_backend(
            self.shared_experts.gate_up_proj
@@ -666,6 +674,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
......
@@ -2900,6 +2900,18 @@ def mxfp_supported():
        return False


@lru_cache(maxsize=1)
def is_gfx95_supported():
    """
    Returns whether the current device is a gfx95-series GPU, which supports MX types.
    """
    if torch.version.hip:
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx95"])
    else:
        return False

# LoRA-related constants and utilities
SUPPORTED_LORA_TARGET_MODULES = [
    "q_proj",
......