change / sglang · Commits · 1b2ff4fb

Unverified commit 1b2ff4fb, authored Sep 03, 2025 by Yineng Zhang, committed via GitHub on Sep 03, 2025

Revert "Optimized deepseek-v3/r1 model performance on mxfp4 run (#9671)" (#9959)

parent 2c7ca33a

Showing 7 changed files with 62 additions and 458 deletions (+62 −458)
python/sglang/srt/layers/communicator.py (+5 −41)
python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py (+30 −49)
python/sglang/srt/layers/quantization/quark/utils.py (+0 −97)
python/sglang/srt/layers/quantization/rocm_mxfp4_utils.py (+0 −13)
python/sglang/srt/layers/rocm_linear_utils.py (+0 −44)
python/sglang/srt/models/deepseek_v2.py (+27 −202)
python/sglang/srt/utils.py (+0 −12)
python/sglang/srt/layers/communicator.py
@@ -42,22 +42,10 @@ from sglang.srt.layers.moe import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import (
-    get_bool_env_var,
-    is_cuda,
-    is_flashinfer_available,
-    is_gfx95_supported,
-    is_hip,
-    is_sm100_supported,
-)
+from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
-_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
-_is_gfx95_supported = is_gfx95_supported()
-
-if _use_aiter and _is_gfx95_supported:
-    from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant
 
 FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
@@ -213,7 +201,6 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
-        qaunt_format: str = "",
     ):
         if hidden_states.shape[0] == 0:
             residual = hidden_states
@@ -231,34 +218,11 @@ class LayerCommunicator:
         else:
             if residual is None:
                 residual = hidden_states
-                if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
-                    hidden_states = fused_rms_mxfp4_quant(
-                        hidden_states,
-                        self.input_layernorm.weight,
-                        self.input_layernorm.variance_epsilon,
-                        None,
-                        None,
-                        None,
-                        None,
-                    )
-                else:
-                    hidden_states = self.input_layernorm(hidden_states)
+                hidden_states = self.input_layernorm(hidden_states)
             else:
-                if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
-                    hidden_states, residual = fused_rms_mxfp4_quant(
-                        hidden_states,
-                        self.input_layernorm.weight,
-                        self.input_layernorm.variance_epsilon,
-                        None,
-                        None,
-                        None,
-                        residual,
-                    )
-                else:
-                    hidden_states, residual = self.input_layernorm(hidden_states, residual)
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
         hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
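Note: the fused_rms_mxfp4_quant call removed above fused the decoder layer's input RMSNorm with MXFP4 activation quantization in a single AITER Triton kernel; the restored path just runs self.input_layernorm. For orientation, a minimal unfused RMSNorm sketch in plain PyTorch (illustrative only, not the AITER API — the real kernel additionally emits the MXFP4-packed activation and its block scales):

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Standard RMSNorm: scale each row by its reciprocal RMS, then by the
    # learned weight. The reverted kernel performed this normalization and, in
    # the same pass, quantized the result to MXFP4 (packed uint8 + e8m0 scales).
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

x = torch.randn(4, 16, dtype=torch.bfloat16)
w = torch.ones(16, dtype=torch.bfloat16)
print(rms_norm_ref(x, w, eps=1e-6).shape)  # torch.Size([4, 16])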
python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
@@ -8,7 +8,6 @@ import torch.nn.functional as F
 from aiter.ops.gemm_op_a4w4 import gemm_a4w4
 from aiter.ops.shuffle import shuffle_weight
 from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
-from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
 from aiter.ops.triton.quant import dynamic_mxfp4_quant
 from aiter.utility import dtypes
 from aiter.utility.fp4_utils import e8m0_shuffle
@@ -39,6 +38,15 @@ class QuarkW4A4MXFP4(QuarkScheme):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        return
+        # for aiter implement
+        # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
+        # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
+
+        # layer.weight = torch.nn.Parameter(wshuffle,
+        #                                   requires_grad=False)
+        # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
+        #                                         requires_grad=False)
 
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -85,53 +93,26 @@ class QuarkW4A4MXFP4(QuarkScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # This path does not have support for bias currently
         assert bias is None, "bias is not supported"
-        three_d = False
-        x_s = None
-        y = None
-        if isinstance(x, tuple):
-            assert len(x) in [
-                2,
-                3,
-            ], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
-            if len(x) == 2:
-                x, x_s = x
-            elif len(x) == 3:
-                x, x_s, y = x
-
-        use_fused_quant_gemm = (
-            x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
-        )
 
         out_dtype = x.dtype
 
         # M = x.shape[0]
         # N = layer.weight.shape[0]
 
         # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
         # x, x_scales_shuffle = quant_func(x, shuffle=True)
 
         # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
 
         # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
 
         # return out[:M]
 
         # triton implement
+        x_q, x_s = dynamic_mxfp4_quant(x)
+        y = torch.empty(
+            x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
+        )
-        if x.dim() == 3:
-            three_d = True
-            x = x.view(-1, x.shape[-1])
-        output_shape = [*x.shape[:-1], layer.weight.shape[0]]
-
-        # use_fused_quant_gemm = true, x_q is a bf16/fp16 num
-        # x_s is not None = true, x_q is uint8 num
-        if use_fused_quant_gemm or x_s is not None:
-            x_q = x
-        else:
-            x_q, x_s = dynamic_mxfp4_quant(x)
-
-        if y is None:
-            y = torch.empty(
-                x_q.shape[0],
-                layer.weight.shape[0],
-                device=x_q.device,
-                dtype=self.out_dtype,
-            )
-
-        if use_fused_quant_gemm:
-            gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
-            y = y.to(x.dtype)
-        else:
-            gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)
-
-        if three_d:
-            return y.view(*output_shape)
-
-        return y
+        out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
+        return out
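Note: the restored apply path is a straight "quantize, then GEMM" sequence — dynamic_mxfp4_quant packs the activation and gemm_afp4wfp4 writes into a preallocated output. A small sketch of the shape contract this implies, inferred from the d // 2 and d // 32 reshapes elsewhere in this commit (placeholder tensors only, no AITER calls; the exact scale dtype is an assumption):

import torch

def mxfp4_activation_shapes(M: int, K: int, N: int):
    # Two FP4 codes are packed per uint8, and one e8m0 scale covers each block
    # of 32 elements; the GEMM output buffer is allocated as (M, N).
    x_q = torch.empty(M, K // 2, dtype=torch.uint8)   # packed FP4 activation
    x_s = torch.empty(M, K // 32, dtype=torch.uint8)  # per-block e8m0 scales
    y = torch.empty(M, N, dtype=torch.bfloat16)       # GEMM output
    return x_q, x_s, y

x_q, x_s, y = mxfp4_activation_shapes(M=8, K=7168, N=1536)
print(x_q.shape, x_s.shape, y.shape)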
python/sglang/srt/layers/quantization/quark/utils.py
@@ -5,10 +5,6 @@ from collections.abc import Iterable, Mapping
 from types import MappingProxyType
 from typing import Any, Optional
 
-import torch
-from aiter.ops.triton.quant import dynamic_mxfp4_quant
-from torch import nn
-
 
 def deep_compare(dict1: Any, dict2: Any) -> bool:
     if type(dict1) is not type(dict2):
@@ -109,96 +105,3 @@ def _is_equal_or_regex_match(
         elif target == value:
             return True
     return False
-
-
-# utility for tensor dims > 2 cases
-def b_dynamic_mxfp4_quant(x):
-    h, b, d = x.shape
-    x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
-    return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
-
-
-def mxfp4_to_f32(x, is_threed):
-    # 2 because we pack fp4 in uint8.
-    x = x.repeat_interleave(2, dim=-1)
-    if is_threed:
-        x[..., ::2] = x[..., ::2] & 0xF
-        x[..., 1::2] = x[..., 1::2] >> 4
-    else:
-        x[:, ::2] = x[:, ::2] & 0xF
-        x[:, 1::2] = x[:, 1::2] >> 4
-
-    mxfp4_list = [
-        0.0,
-        0.5,
-        1.0,
-        1.5,
-        2.0,
-        3.0,
-        4.0,
-        6.0,
-        -0.0,
-        -0.5,
-        -1.0,
-        -1.5,
-        -2.0,
-        -3.0,
-        -4.0,
-        -6.0,
-    ]
-    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
-    return mxfp4_in_f32[x.long()]
-
-
-def e8m0_to_f32(x):
-    # Convert the input tensor `x` (assumed to be in e8m0 format) to float32.
-    # e8m0 is a custom 8-bit floating point format with 8 bits for exponent, 0 for mantissa.
-    # This means the value is essentially 2^(exponent - 127), similar to how IEEE-754 stores floats.
-    # Convert x to float32 for computation, and compute the power of 2 by subtracting the bias (127).
-    x_f32 = 2 ** ((x.to(torch.float32)) - 127)
-
-    # If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf.
-    # Since this custom format has no mantissa, treat 2^128 as NaN.
-    x_f32[x_f32 == 128] = float("nan")
-    return x_f32
-
-
-def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
-    if "mxfp4" in quant_format:
-        # when dtype is bf16, the processing flow is to dynamic quantize bf16 tensor to uint8 tensor
-        # do w_kc (bf16) first to get the w_kc(uint8) w_s_kc(uint8)
-        # and w_vc repeating the same procedure of w_kc to get w_vc(uint8) w_s_vc(uint8)
-        if w.dtype == torch.bfloat16:
-            w_kc, w_vc = w.unflatten(
-                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
-            w_kc = w_kc.transpose(-2, -1)
-            w_s_kc = w_s_kc.transpose(-2, -1)
-            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
-            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
-            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
-        elif w.dtype == torch.uint8:
-            # static quant for mxfp4
-            # when dtype is uint8, it means the w has been quantized to mxfp4 format
-            # but we must separate it to w_kc and w_vc.
-            # The quantized tensor size is only half of original tensor size
-            # and the scaling factor is 1/32, the transpose behavior will be not correct
-            # need to upcast it to fp32 to separate w to w_kc and w_vc
-            # to ensure the following transpose behavior is correct
-            # and then do mxfp4 quant again
-            w = mxfp4_to_f32(w, True).to(torch.bfloat16)
-            w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
-            w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
-            w = w * w_scales
-            w_kc, w_vc = w.unflatten(
-                0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
-            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
-            w_kc = w_kc.transpose(-2, -1)
-            w_s_kc = w_s_kc.transpose(-2, -1)
-            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
-            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
-            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
-
-        return w_kc, w_s_kc, w_vc, w_s_vc
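Note: the deleted helpers above spell out how MXFP4 data is decoded — each uint8 holds two 4-bit codes (low nibble first), each code indexes a 16-entry value table, and every 32 decoded values share one e8m0 scale interpreted as 2**(exponent - 127). A self-contained reference dequantizer mirroring that logic (device-agnostic sketch, not part of the repository):

import torch

FP4_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def dequant_mxfp4_ref(packed: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # Unpack nibbles the same way mxfp4_to_f32 did, look the codes up in the
    # FP4 value table, then apply the per-32-element e8m0 block scale.
    codes = packed.repeat_interleave(2, dim=-1)
    codes[..., ::2] &= 0xF
    codes[..., 1::2] >>= 4
    values = FP4_VALUES[codes.long()]
    block_scale = 2.0 ** (scales.float() - 127.0)
    return values * block_scale.repeat_interleave(32, dim=-1)

packed = torch.randint(0, 256, (2, 32), dtype=torch.uint8)  # 64 FP4 values per row
scales = torch.full((2, 2), 127, dtype=torch.uint8)          # 2**0 = 1.0 per block
print(dequant_mxfp4_ref(packed, scales).shape)  # torch.Size([2, 64])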
python/sglang/srt/layers/quantization/rocm_mxfp4_utils.py (deleted, 100644 → 0)
-from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
-    batched_gemm_afp4wfp4_pre_quant,
-)
-from aiter.ops.triton.fused_mxfp4_quant import (
-    fused_flatten_mxfp4_quant,
-    fused_rms_mxfp4_quant,
-)
-
-__all__ = [
-    "fused_rms_mxfp4_quant",
-    "fused_flatten_mxfp4_quant",
-    "batched_gemm_afp4wfp4_pre_quant",
-]
python/sglang/srt/layers/rocm_linear_utils.py (deleted, 100644 → 0)
-import torch
-from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
-from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
-from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
-
-from sglang.srt.utils import BumpAllocator
-
-__all__ = ["fused_qk_rope_cat"]
-
-
-def aiter_dsv3_router_gemm(
-    hidden_states: torch.Tensor,
-    weight: torch.Tensor,
-    gemm_output_zero_allocator: BumpAllocator = None,
-):
-    M = hidden_states.shape[0]
-    N = weight.shape[0]
-    y = None
-
-    if M <= 256:
-        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
-        # for now it is also coupled with zero allocator.
-        if gemm_output_zero_allocator != None:
-            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
-        else:
-            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
-
-    if y is not None:
-        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
-    else:
-        logits = gemm_a16w16(hidden_states, weight)
-
-    return logits
-
-
-def get_dsv3_gemm_output_zero_allocator_size(
-    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
-):
-    if embedding_dim != 7168 or n_routed_experts != 256:
-        return 0
-
-    per_layer_size = 256 * (allocate_size + n_routed_experts)
-
-    return num_moe_layers * per_layer_size
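Note: the deleted get_dsv3_gemm_output_zero_allocator_size only returns a non-zero buffer size for the DeepSeek-V3-style configuration (embedding_dim 7168, 256 routed experts). A small worked example of its arithmetic; the allocate_size (shared-experts gate_up_proj output per partition) and the MoE layer count below are hypothetical:

def zero_allocator_size(n_routed_experts, num_moe_layers, allocate_size, embedding_dim):
    # Re-statement of the deleted helper, kept only for this worked example.
    if embedding_dim != 7168 or n_routed_experts != 256:
        return 0
    per_layer_size = 256 * (allocate_size + n_routed_experts)
    return num_moe_layers * per_layer_size

# e.g. allocate_size=4096 and 58 MoE layers:
# 256 * (4096 + 256) = 1,114,112 fp32 elements per layer; times 58 layers = 64,618,496
print(zero_allocator_size(256, 58, 4096, 7168))  # 64618496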
python/sglang/srt/models/deepseek_v2.py
@@ -112,7 +112,6 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_flashinfer_available,
-    is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
@@ -130,22 +129,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
-_is_gfx95_supported = is_gfx95_supported()
-
-_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
-
-if _use_aiter_gfx95:
-    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
-    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
-        batched_gemm_afp4wfp4_pre_quant,
-        fused_flatten_mxfp4_quant,
-        fused_rms_mxfp4_quant,
-    )
-    from sglang.srt.layers.rocm_linear_utils import (
-        aiter_dsv3_router_gemm,
-        fused_qk_rope_cat,
-        get_dsv3_gemm_output_zero_allocator_size,
-    )
 
 if _is_cuda:
     from sgl_kernel import (
@@ -241,17 +224,10 @@ class DeepseekV2MLP(nn.Module):
         forward_batch=None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
-        if gemm_output_zero_allocator != None and x.shape[0] <= 256:
-            y = gemm_output_zero_allocator.allocate(
-                x.shape[0] * self.gate_up_proj.output_size_per_partition
-            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
-            x = (x, None, y)
-
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
@@ -281,7 +257,7 @@ class MoEGate(nn.Module):
         if _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(weight_names=["weight"])
 
-    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
+    def forward(self, hidden_states):
         if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
@@ -300,10 +276,6 @@ class MoEGate(nn.Module):
         ):
             # router gemm output float32
             logits = dsv3_router_gemm(hidden_states, self.weight)
-        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
-            logits = aiter_dsv3_router_gemm(
-                hidden_states, self.weight, gemm_output_zero_allocator
-            )
         else:
             logits = F.linear(hidden_states, self.weight, None)
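Note: after this revert the router gate falls back to a plain linear projection; the removed aiter_dsv3_router_gemm produced the same (num_tokens, n_experts) logits through a Triton GEMM that accumulates in fp32 into a preallocated zero buffer. A quick illustration of the fallback's semantics (shapes chosen to echo 256 routed experts and hidden size 7168; not repository code):

import torch
import torch.nn.functional as F

hidden_states = torch.randn(8, 7168)
gate_weight = torch.randn(256, 7168)

logits = F.linear(hidden_states, gate_weight, None)  # (8, 256)
assert torch.allclose(logits, hidden_states @ gate_weight.t(), rtol=1e-4, atol=1e-4)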
@@ -467,7 +439,6 @@ class DeepseekV2MoE(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -481,14 +452,12 @@ class DeepseekV2MoE(nn.Module):
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
-                    gemm_output_zero_allocator,
                 )
             else:
                 return self.forward_normal(
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
-                    gemm_output_zero_allocator,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -498,7 +467,6 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         current_stream = torch.cuda.current_stream()
@@ -507,7 +475,7 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
+            router_logits = self.gate(hidden_states)
             topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda:
@@ -534,7 +502,6 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -544,7 +511,7 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             shared_output = self._forward_shared_experts(hidden_states)
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
+            router_logits = self.gate(hidden_states)
             topk_output = self.topk(hidden_states, router_logits)
         else:
             shared_output = None
@@ -1130,19 +1097,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.attn_mha.kv_b_proj is None:
             self.attn_mha.kv_b_proj = self.kv_b_proj
 
-        # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
-        if isinstance(hidden_states, tuple):
-            if hidden_states[0].shape[0] == 0:
-                assert (
-                    not self.o_proj.reduce_results
-                ), "short-circuiting allreduce will lead to hangs"
-                return hidden_states[0]
-        else:
-            if hidden_states.shape[0] == 0:
-                assert (
-                    not self.o_proj.reduce_results
-                ), "short-circuiting allreduce will lead to hangs"
-                return hidden_states, None, forward_batch, None
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states, None, forward_batch, None
 
         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
@@ -1266,11 +1225,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
-            if (
-                (not isinstance(hidden_states, tuple))
-                and hidden_states.shape[0] <= 16
-                and self.use_min_latency_fused_a_gemm
-            ):
+            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
                 fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                     hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                 )
@@ -1290,18 +1245,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 k_nope = self.kv_a_layernorm(k_nope)
             current_stream.wait_stream(self.alt_stream)
         else:
-            if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
-                q, k_nope = fused_rms_mxfp4_quant(
-                    q,
-                    self.q_a_layernorm.weight,
-                    self.q_a_layernorm.variance_epsilon,
-                    k_nope,
-                    self.kv_a_layernorm.weight,
-                    self.kv_a_layernorm.variance_epsilon,
-                )
-            else:
-                q = self.q_a_layernorm(q)
-                k_nope = self.kv_a_layernorm(k_nope)
+            q = self.q_a_layernorm(q)
+            k_nope = self.kv_a_layernorm(k_nope)
 
         k_nope = k_nope.unsqueeze(1)
         q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
@@ -1333,27 +1278,10 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_nope_out = q_nope_out[:, :expected_m, :]
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
-                x = q_nope.transpose(0, 1)
-                q_nope_out = torch.empty(
-                    x.shape[0],
-                    x.shape[1],
-                    self.w_kc.shape[2],
-                    device=x.device,
-                    dtype=torch.bfloat16,
-                )
-                batched_gemm_afp4wfp4_pre_quant(
-                    x,
-                    self.w_kc.transpose(-2, -1),
-                    self.w_scale_k.transpose(-2, -1),
-                    torch.bfloat16,
-                    q_nope_out,
-                )
-            else:
-                q_nope_out = torch.bmm(
-                    q_nope.to(torch.bfloat16).transpose(0, 1),
-                    self.w_kc.to(torch.bfloat16) * self.w_scale,
-                )
+            q_nope_out = torch.bmm(
+                q_nope.to(torch.bfloat16).transpose(0, 1),
+                self.w_kc.to(torch.bfloat16) * self.w_scale,
+            )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
@@ -1367,15 +1295,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)
 
-        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
-            not _use_aiter or not _is_gfx95_supported
-        ):
+        if not self._fuse_rope_for_trtllm_mla(forward_batch):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
 
     def forward_absorb_core(
-        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
         if (
             self.current_attention_backend == "fa3"
@@ -1400,23 +1326,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 **extra_args,
             )
         else:
-            if _use_aiter_gfx95:
-                cos = self.rotary_emb.cos_cache
-                sin = self.rotary_emb.sin_cache
-                q, k = fused_qk_rope_cat(
-                    q_nope_out,
-                    q_pe,
-                    k_nope,
-                    k_pe,
-                    positions,
-                    cos,
-                    sin,
-                    self.rotary_emb.is_neox_style,
-                )
-            else:
-                q = torch.cat([q_nope_out, q_pe], dim=-1)
-                k = torch.cat([k_nope, k_pe], dim=-1)
-
+            q = torch.cat([q_nope_out, q_pe], dim=-1)
+            k = torch.cat([k_nope, k_pe], dim=-1)
             attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
@@ -1441,34 +1352,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
-                x = attn_output.transpose(0, 1)
-                attn_bmm_output = torch.empty(
-                    x.shape[0],
-                    x.shape[1],
-                    self.w_vc.shape[2],
-                    device=x.device,
-                    dtype=torch.bfloat16,
-                )
-                batched_gemm_afp4wfp4_pre_quant(
-                    x,
-                    self.w_vc.transpose(-2, -1),
-                    self.w_scale_v.transpose(-2, -1),
-                    torch.bfloat16,
-                    attn_bmm_output,
-                )
-            else:
-                attn_bmm_output = torch.bmm(
-                    attn_output.to(torch.bfloat16).transpose(0, 1),
-                    self.w_vc.to(torch.bfloat16) * self.w_scale,
-                )
-
-            if self.o_proj.weight.dtype == torch.uint8:
-                attn_bmm_output = attn_bmm_output.transpose(0, 1)
-                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
-            else:
-                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
-
+            attn_bmm_output = torch.bmm(
+                attn_output.to(torch.bfloat16).transpose(0, 1),
+                self.w_vc.to(torch.bfloat16) * self.w_scale,
+            )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1976,21 +1864,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
-        quant_format = (
-            "mxfp4"
-            if _is_gfx95_supported
-            and self.self_attn.fused_qkv_a_proj_with_mqa.weight == torch.uint8
-            else ""
-        )
-
         hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states,
-            residual,
-            forward_batch,
-            quant_format,
+            hidden_states, residual, forward_batch
         )
 
         hidden_states = self.self_attn(
@@ -2159,37 +2036,6 @@ class DeepseekV2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
-        self.gemm_output_zero_allocator_size = 0
-        if (
-            _use_aiter_gfx95
-            and config.n_routed_experts == 256
-            and self.embed_tokens.embedding_dim == 7168
-        ):
-            num_moe_layers = sum(
-                [
-                    1
-                    for i in range(len(self.layers))
-                    if isinstance(self.layers[i].mlp, DeepseekV2MoE)
-                ]
-            )
-
-            allocate_size = 0
-            for i in range(len(self.layers)):
-                if isinstance(self.layers[i].mlp, DeepseekV2MoE):
-                    allocate_size = self.layers[
-                        i
-                    ].mlp.shared_experts.gate_up_proj.output_size_per_partition
-                    break
-
-            self.gemm_output_zero_allocator_size = (
-                get_dsv3_gemm_output_zero_allocator_size(
-                    config.n_routed_experts,
-                    num_moe_layers,
-                    allocate_size,
-                    self.embed_tokens.embedding_dim,
-                )
-            )
-
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -2209,16 +2055,6 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )
 
-        gemm_output_zero_allocator = (
-            BumpAllocator(
-                buffer_size=self.gemm_output_zero_allocator_size,
-                dtype=torch.float32,
-                device=device,
-            )
-            if self.gemm_output_zero_allocator_size > 0
-            else None
-        )
-
         if self.pp_group.is_first_rank:
             if input_embeds is None:
                 hidden_states = self.embed_tokens(input_ids)
@@ -2245,12 +2081,7 @@ class DeepseekV2Model(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
-                    positions,
-                    hidden_states,
-                    forward_batch,
-                    residual,
-                    zero_allocator,
-                    gemm_output_zero_allocator,
+                    positions, hidden_states, forward_batch, residual, zero_allocator
                 )
 
         if normal_end_layer != self.end_layer:
@@ -2523,12 +2354,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                 w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
 
-                if _use_aiter_gfx95 and self.quant_config.get_name() == "quark":
-                    w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
-                        quark_post_load_weights(self_attn, w, "mxfp4")
-                    )
-
                 if not use_deep_gemm_bmm:
                     self_attn.w_kc = bind_or_assign(
                         self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
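Note: the context lines above keep the kv_b_proj weight split that feeds the MLA absorb path — the (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank) matrix is unflattened per head and split into w_kc and w_vc. A shape sketch of that unflatten/split (head count and dimensions are example values, not read from a checkpoint):

import torch

num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 16, 128, 128, 512
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
print(w_kc.shape, w_vc.shape)  # torch.Size([16, 128, 512]) torch.Size([16, 128, 512])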
python/sglang/srt/utils.py
@@ -2900,18 +2900,6 @@ def mxfp_supported():
         return False
 
 
-@lru_cache(maxsize=1)
-def is_gfx95_supported():
-    """
-    Returns whether the current platform supports MX types.
-    """
-    if torch.version.hip:
-        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        return any(gfx in gcn_arch for gfx in ["gfx95"])
-    else:
-        return False
-
-
 # LoRA-related constants and utilities
 SUPPORTED_LORA_TARGET_MODULES = [
     "q_proj",