Commit 1b2ff4fb (unverified), authored Sep 03, 2025 by Yineng Zhang, committed by GitHub on Sep 03, 2025

Revert "Optimized deepseek-v3/r1 model performance on mxfp4 run (#9671)" (#9959)
Parent: 2c7ca33a
Showing 7 changed files with 62 additions and 458 deletions (+62, -458)
- python/sglang/srt/layers/communicator.py (+5, -41)
- python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py (+30, -49)
- python/sglang/srt/layers/quantization/quark/utils.py (+0, -97)
- python/sglang/srt/layers/quantization/rocm_mxfp4_utils.py (+0, -13)
- python/sglang/srt/layers/rocm_linear_utils.py (+0, -44)
- python/sglang/srt/models/deepseek_v2.py (+27, -202)
- python/sglang/srt/utils.py (+0, -12)
python/sglang/srt/layers/communicator.py

@@ -42,22 +42,10 @@ from sglang.srt.layers.moe import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import (
-    get_bool_env_var,
-    is_cuda,
-    is_flashinfer_available,
-    is_gfx95_supported,
-    is_hip,
-    is_sm100_supported,
-)
+from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
-_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
-_is_gfx95_supported = is_gfx95_supported()
-
-if _use_aiter and _is_gfx95_supported:
-    from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant
 
 FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048

@@ -213,7 +201,6 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
-        qaunt_format: str = "",
     ):
         if hidden_states.shape[0] == 0:
             residual = hidden_states

@@ -231,30 +218,7 @@ class LayerCommunicator:
         else:
             if residual is None:
                 residual = hidden_states
-                if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
-                    hidden_states = fused_rms_mxfp4_quant(
-                        hidden_states,
-                        self.input_layernorm.weight,
-                        self.input_layernorm.variance_epsilon,
-                        None,
-                        None,
-                        None,
-                        None,
-                    )
-                else:
+                hidden_states = self.input_layernorm(hidden_states)
             else:
-                if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
-                    hidden_states, residual = fused_rms_mxfp4_quant(
-                        hidden_states,
-                        self.input_layernorm.weight,
-                        self.input_layernorm.variance_epsilon,
-                        None,
-                        None,
-                        None,
-                        residual,
-                    )
-                else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual
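Note: the branch removed above switched between aiter's fused_rms_mxfp4_quant kernel and the plain input_layernorm call. For readers without a ROCm/aiter setup, here is a minimal unfused sketch of the RMSNorm half only, assuming standard RMSNorm-with-residual semantics (the helper name and epsilon handling are illustrative, not sglang's kernel; the fused kernel additionally returns the MXFP4-quantized hidden states):

    import torch

    def rms_norm(hidden_states, weight, eps, residual=None):
        # With a residual input, the pre-norm sum becomes the new residual stream.
        if residual is not None:
            hidden_states = hidden_states + residual
            residual = hidden_states
        variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
        out = (hidden_states.float() * torch.rsqrt(variance + eps)).to(hidden_states.dtype) * weight
        return out if residual is None else (out, residual)

    h = torch.randn(4, 8, dtype=torch.bfloat16)
    w = torch.ones(8, dtype=torch.bfloat16)
    normed = rms_norm(h, w, 1e-6)                            # residual is None branch
    normed, new_residual = rms_norm(h, w, 1e-6, h.clone())   # residual branch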
python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py

@@ -8,7 +8,6 @@ import torch.nn.functional as F
 from aiter.ops.gemm_op_a4w4 import gemm_a4w4
 from aiter.ops.shuffle import shuffle_weight
 from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
-from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
 from aiter.ops.triton.quant import dynamic_mxfp4_quant
 from aiter.utility import dtypes
 from aiter.utility.fp4_utils import e8m0_shuffle

@@ -39,6 +38,15 @@ class QuarkW4A4MXFP4(QuarkScheme):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        return
+        # for aiter implement
+        # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
+        # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
+        # layer.weight = torch.nn.Parameter(wshuffle,
+        #                                   requires_grad=False)
+        # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
+        #                                         requires_grad=False)
 
     def create_weights(
         self,
         layer: torch.nn.Module,

@@ -85,53 +93,26 @@ class QuarkW4A4MXFP4(QuarkScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # This path does not have support for bias currently
         assert bias is None, "bias is not supported"
 
         three_d = False
-        x_s = None
-        y = None
-        if isinstance(x, tuple):
-            assert len(x) in [
-                2,
-                3,
-            ], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
-            if len(x) == 2:
-                x, x_s = x
-            elif len(x) == 3:
-                x, x_s, y = x
-
-        use_fused_quant_gemm = (
-            x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
-        )
-
         if x.dim() == 3:
             three_d = True
             x = x.view(-1, x.shape[-1])
         output_shape = [*x.shape[:-1], layer.weight.shape[0]]
+        out_dtype = x.dtype
 
         # M = x.shape[0]
         # N = layer.weight.shape[0]
-        # use_fused_quant_gemm = true, x_q is a bf16/fp16 num
-        # x_s is not None = true, x_q is uint8 num
-        if use_fused_quant_gemm or x_s is not None:
-            x_q = x
-        else:
-            x_q, x_s = dynamic_mxfp4_quant(x)
-
-        if y is None:
-            y = torch.empty(
-                x_q.shape[0],
-                layer.weight.shape[0],
-                device=x_q.device,
-                dtype=self.out_dtype,
-            )
-
-        if use_fused_quant_gemm:
-            gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
-            y = y.to(x.dtype)
-        else:
-            gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)
+        # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
+        # x, x_scales_shuffle = quant_func(x, shuffle=True)
+        # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
+        # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
+        # return out[:M]
+        # triton implement
+        x_q, x_s = dynamic_mxfp4_quant(x)
+        y = torch.empty(
+            x_q.shape[0],
+            layer.weight.shape[0],
+            device=x_q.device,
+            dtype=out_dtype,
+        )
+        out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
 
         if three_d:
             return y.view(*output_shape)
 
-        return y
+        return out
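Note: the removed path above accepted activations either as a plain tensor or as a tuple carrying pre-computed MXFP4 scales and, optionally, a preallocated output buffer. A small standalone sketch of that unpacking and of the use_fused_quant_gemm condition, with made-up shapes and no aiter calls:

    import torch

    def unpack_mxfp4_input(x):
        # Accepts x, (x, x_s) or (x, x_s, y); x_s are MXFP4 scales, y is an output buffer.
        x_s = y = None
        if isinstance(x, tuple):
            assert len(x) in (2, 3), "only (x, x_s) or (x, x_s, y) are accepted"
            if len(x) == 2:
                x, x_s = x
            else:
                x, x_s, y = x
        return x, x_s, y

    weight_rows = 16  # stand-in for layer.weight.shape[0]
    x, x_s, y = unpack_mxfp4_input((torch.randn(4, 32), None, torch.empty(4, 16)))
    # The fused pre-quant GEMM was only used when no scales were supplied but a
    # matching output buffer was.
    use_fused_quant_gemm = x_s is None and y is not None and weight_rows == y.shape[1]
    print(use_fused_quant_gemm)  # True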
python/sglang/srt/layers/quantization/quark/utils.py

@@ -5,10 +5,6 @@ from collections.abc import Iterable, Mapping
 from types import MappingProxyType
 from typing import Any, Optional
 
-import torch
-from aiter.ops.triton.quant import dynamic_mxfp4_quant
-from torch import nn
-
 
 def deep_compare(dict1: Any, dict2: Any) -> bool:
     if type(dict1) is not type(dict2):

@@ -109,96 +105,3 @@ def _is_equal_or_regex_match(
     elif target == value:
         return True
     return False
-
-
-# utility for tensor dims > 2 cases
-def b_dynamic_mxfp4_quant(x):
-    h, b, d = x.shape
-    x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
-    return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
-
-
-def mxfp4_to_f32(x, is_threed):
-    # 2 because we pack fp4 in uint8.
-    x = x.repeat_interleave(2, dim=-1)
-    if is_threed:
-        x[..., ::2] = x[..., ::2] & 0xF
-        x[..., 1::2] = x[..., 1::2] >> 4
-    else:
-        x[:, ::2] = x[:, ::2] & 0xF
-        x[:, 1::2] = x[:, 1::2] >> 4
-
-    mxfp4_list = [
-        0.0,
-        0.5,
-        1.0,
-        1.5,
-        2.0,
-        3.0,
-        4.0,
-        6.0,
-        -0.0,
-        -0.5,
-        -1.0,
-        -1.5,
-        -2.0,
-        -3.0,
-        -4.0,
-        -6.0,
-    ]
-    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
-    return mxfp4_in_f32[x.long()]
-
-
-def e8m0_to_f32(x):
-    # Convert the input tensor `x` (assumed to be in e8m0 format) to float32.
-    # e8m0 is a custom 8-bit floating point format with 8 bits for exponent, 0 for mantissa.
-    # This means the value is essentially 2^(exponent - 127), similar to how IEEE-754 stores floats.
-    # Convert x to float32 for computation, and compute the power of 2 by subtracting the bias (127).
-    x_f32 = 2 ** ((x.to(torch.float32)) - 127)
-
-    # If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf.
-    # Since this custom format has no mantissa, treat 2^128 as NaN.
-    x_f32[x_f32 == 128] = float("nan")
-    return x_f32
-
-
-def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
-    if "mxfp4" in quant_format:
-        # when dtype is bf16, the processing flow is to dynamic quantize bf16 tensor to uint8 tensor
-        # do w_kc (bf16) first to get the w_kc(uint8) w_s_kc(uint8)
-        # and w_vc repeating the same procedure of w_kc to get w_vc(uint8) w_s_vc(uint8)
-        if w.dtype == torch.bfloat16:
-            w_kc, w_vc = w.unflatten(
-                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
-            w_kc = w_kc.transpose(-2, -1)
-            w_s_kc = w_s_kc.transpose(-2, -1)
-            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
-            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
-            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
-        elif w.dtype == torch.uint8:
-            # static quant for mxfp4
-            # when dtype is uint8, it means the w has been quantized to mxfp4 format
-            # but we must separate it to w_kc and w_vc.
-            # The quantized tensor size is only half of original tensor size
-            # and the scaling factor is 1/32, the transpose behavior will be not correct
-            # need to upcast it to fp32 to separate w to w_kc and w_vc
-            # to ensure the following transpose behavior is correct
-            # and then do mxfp4 quant again
-            w = mxfp4_to_f32(w, True).to(torch.bfloat16)
-            w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
-            w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
-            w = w * w_scales
-            w_kc, w_vc = w.unflatten(
-                0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
-            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
-            w_kc = w_kc.transpose(-2, -1)
-            w_s_kc = w_s_kc.transpose(-2, -1)
-            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
-            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
-            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
-
-        return w_kc, w_s_kc, w_vc, w_s_vc
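Note: the removed mxfp4_to_f32 and e8m0_to_f32 helpers above decode packed MXFP4 data on the CUDA device. A CPU-runnable sketch of the same decode, using the identical 16-entry lookup table and the 2 ** (exponent - 127) rule (nibble order matches the even/odd indexing above):

    import torch

    MXFP4_LUT = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        dtype=torch.float32,
    )

    def mxfp4_to_f32_cpu(x_packed):
        # Each uint8 holds two fp4 values: low nibble first, then high nibble.
        lo = x_packed & 0xF
        hi = x_packed >> 4
        idx = torch.stack([lo, hi], dim=-1).flatten(-2)
        return MXFP4_LUT[idx.long()]

    def e8m0_to_f32_cpu(scale):
        # e8m0: 8 exponent bits, no mantissa -> 2 ** (exponent - 127).
        return 2.0 ** (scale.to(torch.float32) - 127)

    packed = torch.tensor([[0x21, 0xA9]], dtype=torch.uint8)   # nibbles 1, 2, 9, 10
    scales = torch.tensor([[127]], dtype=torch.uint8)          # exponent 127 -> scale 1.0
    print(mxfp4_to_f32_cpu(packed) * e8m0_to_f32_cpu(scales))  # [[0.5, 1.0, -0.5, -1.0]]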
python/sglang/srt/layers/quantization/rocm_mxfp4_utils.py (deleted, 100644 → 0)

-from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
-    batched_gemm_afp4wfp4_pre_quant,
-)
-from aiter.ops.triton.fused_mxfp4_quant import (
-    fused_flatten_mxfp4_quant,
-    fused_rms_mxfp4_quant,
-)
-
-__all__ = [
-    "fused_rms_mxfp4_quant",
-    "fused_flatten_mxfp4_quant",
-    "batched_gemm_afp4wfp4_pre_quant",
-]
python/sglang/srt/layers/rocm_linear_utils.py (deleted, 100644 → 0)

-import torch
-from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
-from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
-from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
-
-from sglang.srt.utils import BumpAllocator
-
-__all__ = ["fused_qk_rope_cat"]
-
-
-def aiter_dsv3_router_gemm(
-    hidden_states: torch.Tensor,
-    weight: torch.Tensor,
-    gemm_output_zero_allocator: BumpAllocator = None,
-):
-    M = hidden_states.shape[0]
-    N = weight.shape[0]
-    y = None
-
-    if M <= 256:
-        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
-        # for now it is also coupled with zero allocator.
-        if gemm_output_zero_allocator != None:
-            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
-        else:
-            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
-
-    if y is not None:
-        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
-    else:
-        logits = gemm_a16w16(hidden_states, weight)
-
-    return logits
-
-
-def get_dsv3_gemm_output_zero_allocator_size(
-    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
-):
-    if embedding_dim != 7168 or n_routed_experts != 256:
-        return 0
-
-    per_layer_size = 256 * (allocate_size + n_routed_experts)
-
-    return num_moe_layers * per_layer_size
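Note: the deleted sizing helper above reduces to simple arithmetic; a worked example with hypothetical inputs (the num_moe_layers and allocate_size values below are made up for illustration):

    def get_dsv3_gemm_output_zero_allocator_size(
        n_routed_experts, num_moe_layers, allocate_size, embedding_dim
    ):
        if embedding_dim != 7168 or n_routed_experts != 256:
            return 0
        per_layer_size = 256 * (allocate_size + n_routed_experts)
        return num_moe_layers * per_layer_size

    # Hypothetical: 58 MoE layers and a shared-expert gate_up width of 512 per rank.
    print(get_dsv3_gemm_output_zero_allocator_size(256, 58, 512, 7168))
    # 58 * 256 * (512 + 256) = 11,403,264 float32 elements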
python/sglang/srt/models/deepseek_v2.py

@@ -112,7 +112,6 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_flashinfer_available,
-    is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,

@@ -130,22 +129,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
-_is_gfx95_supported = is_gfx95_supported()
-
-_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
-
-if _use_aiter_gfx95:
-    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
-    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
-        batched_gemm_afp4wfp4_pre_quant,
-        fused_flatten_mxfp4_quant,
-        fused_rms_mxfp4_quant,
-    )
-    from sglang.srt.layers.rocm_linear_utils import (
-        aiter_dsv3_router_gemm,
-        fused_qk_rope_cat,
-        get_dsv3_gemm_output_zero_allocator_size,
-    )
 
 if _is_cuda:
     from sgl_kernel import (

@@ -241,17 +224,10 @@ class DeepseekV2MLP(nn.Module):
         forward_batch=None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
-        if gemm_output_zero_allocator != None and x.shape[0] <= 256:
-            y = gemm_output_zero_allocator.allocate(
-                x.shape[0] * self.gate_up_proj.output_size_per_partition
-            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
-            x = (x, None, y)
-
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(

@@ -281,7 +257,7 @@ class MoEGate(nn.Module):
         if _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(weight_names=["weight"])
 
-    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
+    def forward(self, hidden_states):
         if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,

@@ -300,10 +276,6 @@ class MoEGate(nn.Module):
         ):
             # router gemm output float32
             logits = dsv3_router_gemm(hidden_states, self.weight)
-        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
-            logits = aiter_dsv3_router_gemm(
-                hidden_states, self.weight, gemm_output_zero_allocator
-            )
         else:
             logits = F.linear(hidden_states, self.weight, None)

@@ -467,7 +439,6 @@ class DeepseekV2MoE(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024

@@ -481,14 +452,12 @@ class DeepseekV2MoE(nn.Module):
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
-                    gemm_output_zero_allocator,
                 )
             else:
                 return self.forward_normal(
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
-                    gemm_output_zero_allocator,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)

@@ -498,7 +467,6 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         current_stream = torch.cuda.current_stream()

@@ -507,7 +475,7 @@ class DeepseekV2MoE(nn.Module):
             with torch.cuda.stream(self.alt_stream):
                 # router_logits: (num_tokens, n_experts)
-                router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
+                router_logits = self.gate(hidden_states)
                 topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:

@@ -534,7 +502,6 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj

@@ -544,7 +511,7 @@ class DeepseekV2MoE(nn.Module):
             if hidden_states.shape[0] > 0:
                 shared_output = self._forward_shared_experts(hidden_states)
                 # router_logits: (num_tokens, n_experts)
-                router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
+                router_logits = self.gate(hidden_states)
                 topk_output = self.topk(hidden_states, router_logits)
             else:
                 shared_output = None

@@ -1130,14 +1097,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.attn_mha.kv_b_proj is None:
             self.attn_mha.kv_b_proj = self.kv_b_proj
 
-        # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
-        if isinstance(hidden_states, tuple):
-            if hidden_states[0].shape[0] == 0:
-                assert (
-                    not self.o_proj.reduce_results
-                ), "short-circuiting allreduce will lead to hangs"
-                return hidden_states[0]
-        else:
-            if hidden_states.shape[0] == 0:
-                assert (
-                    not self.o_proj.reduce_results
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results

@@ -1266,11 +1225,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
-            if (
-                (not isinstance(hidden_states, tuple))
-                and hidden_states.shape[0] <= 16
-                and self.use_min_latency_fused_a_gemm
-            ):
+            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
                 fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                     hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                 )

@@ -1289,16 +1244,6 @@ class DeepseekV2AttentionMLA(nn.Module):
                 with torch.cuda.stream(self.alt_stream):
                     k_nope = self.kv_a_layernorm(k_nope)
                 current_stream.wait_stream(self.alt_stream)
             else:
-                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
-                    q, k_nope = fused_rms_mxfp4_quant(
-                        q,
-                        self.q_a_layernorm.weight,
-                        self.q_a_layernorm.variance_epsilon,
-                        k_nope,
-                        self.kv_a_layernorm.weight,
-                        self.kv_a_layernorm.variance_epsilon,
-                    )
-                else:
+                q = self.q_a_layernorm(q)
+                k_nope = self.kv_a_layernorm(k_nope)

@@ -1333,23 +1278,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out[:, :expected_m, :]
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
-                x = q_nope.transpose(0, 1)
-                q_nope_out = torch.empty(
-                    x.shape[0],
-                    x.shape[1],
-                    self.w_kc.shape[2],
-                    device=x.device,
-                    dtype=torch.bfloat16,
-                )
-                batched_gemm_afp4wfp4_pre_quant(
-                    x,
-                    self.w_kc.transpose(-2, -1),
-                    self.w_scale_k.transpose(-2, -1),
-                    torch.bfloat16,
-                    q_nope_out,
-                )
-            else:
+            q_nope_out = torch.bmm(
+                q_nope.to(torch.bfloat16).transpose(0, 1),
+                self.w_kc.to(torch.bfloat16) * self.w_scale,

@@ -1367,15 +1295,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)
 
-        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
-            not _use_aiter or not _is_gfx95_supported
-        ):
+        if not self._fuse_rope_for_trtllm_mla(forward_batch):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
 
     def forward_absorb_core(
-        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
         if (
             self.current_attention_backend == "fa3"

@@ -1399,24 +1325,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                 k_rope=k_pe,
                 **extra_args,
             )
         else:
-            if _use_aiter_gfx95:
-                cos = self.rotary_emb.cos_cache
-                sin = self.rotary_emb.sin_cache
-                q, k = fused_qk_rope_cat(
-                    q_nope_out,
-                    q_pe,
-                    k_nope,
-                    k_pe,
-                    positions,
-                    cos,
-                    sin,
-                    self.rotary_emb.is_neox_style,
-                )
-            else:
+            q = torch.cat([q_nope_out, q_pe], dim=-1)
+            k = torch.cat([k_nope, k_pe], dim=-1)
 
             attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

@@ -1441,34 +1352,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
-                x = attn_output.transpose(0, 1)
-                attn_bmm_output = torch.empty(
-                    x.shape[0],
-                    x.shape[1],
-                    self.w_vc.shape[2],
-                    device=x.device,
-                    dtype=torch.bfloat16,
-                )
-                batched_gemm_afp4wfp4_pre_quant(
-                    x,
-                    self.w_vc.transpose(-2, -1),
-                    self.w_scale_v.transpose(-2, -1),
-                    torch.bfloat16,
-                    attn_bmm_output,
-                )
-            else:
+            attn_bmm_output = torch.bmm(
+                attn_output.to(torch.bfloat16).transpose(0, 1),
+                self.w_vc.to(torch.bfloat16) * self.w_scale,
+            )
-            if self.o_proj.weight.dtype == torch.uint8:
-                attn_bmm_output = attn_bmm_output.transpose(0, 1)
-                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
-            else:
-                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),

@@ -1976,21 +1864,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
-        quant_format = (
-            "mxfp4"
-            if _is_gfx95_supported
-            and self.self_attn.fused_qkv_a_proj_with_mqa.weight == torch.uint8
-            else ""
-        )
-
         hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch, quant_format,
+            hidden_states, residual, forward_batch
         )
 
         hidden_states = self.self_attn(

@@ -2159,37 +2036,6 @@ class DeepseekV2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
-        self.gemm_output_zero_allocator_size = 0
-        if (
-            _use_aiter_gfx95
-            and config.n_routed_experts == 256
-            and self.embed_tokens.embedding_dim == 7168
-        ):
-            num_moe_layers = sum(
-                [
-                    1
-                    for i in range(len(self.layers))
-                    if isinstance(self.layers[i].mlp, DeepseekV2MoE)
-                ]
-            )
-
-            allocate_size = 0
-            for i in range(len(self.layers)):
-                if isinstance(self.layers[i].mlp, DeepseekV2MoE):
-                    allocate_size = self.layers[
-                        i
-                    ].mlp.shared_experts.gate_up_proj.output_size_per_partition
-                    break
-
-            self.gemm_output_zero_allocator_size = (
-                get_dsv3_gemm_output_zero_allocator_size(
-                    config.n_routed_experts,
-                    num_moe_layers,
-                    allocate_size,
-                    self.embed_tokens.embedding_dim,
-                )
-            )
-
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens

@@ -2209,16 +2055,6 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )
 
-        gemm_output_zero_allocator = (
-            BumpAllocator(
-                buffer_size=self.gemm_output_zero_allocator_size,
-                dtype=torch.float32,
-                device=device,
-            )
-            if self.gemm_output_zero_allocator_size > 0
-            else None
-        )
-
         if self.pp_group.is_first_rank:
             if input_embeds is None:
                 hidden_states = self.embed_tokens(input_ids)

@@ -2245,12 +2081,7 @@ class DeepseekV2Model(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
-                    positions,
-                    hidden_states,
-                    forward_batch,
-                    residual,
-                    zero_allocator,
-                    gemm_output_zero_allocator,
+                    positions, hidden_states, forward_batch, residual, zero_allocator
                 )
 
         if normal_end_layer != self.end_layer:

@@ -2523,12 +2354,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
 
-            if _use_aiter_gfx95 and self.quant_config.get_name() == "quark":
-                w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
-                    quark_post_load_weights(self_attn, w, "mxfp4")
-                )
-
             if not use_deep_gemm_bmm:
                 self_attn.w_kc = bind_or_assign(
                     self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
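Note: after this revert, the ROCm path always falls back to the bf16 torch.bmm shown above instead of batched_gemm_afp4wfp4_pre_quant. A shape sketch of that fallback with hypothetical sizes (the head count and dims below are illustrative, not the model's real configuration):

    import torch

    num_heads, num_tokens, qk_nope_head_dim, kv_lora_rank = 16, 4, 128, 512
    q_nope = torch.randn(num_tokens, num_heads, qk_nope_head_dim, dtype=torch.bfloat16)
    w_kc = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank, dtype=torch.bfloat16)
    w_scale = 1.0

    # Mirrors: torch.bmm(q_nope.to(torch.bfloat16).transpose(0, 1),
    #                    self.w_kc.to(torch.bfloat16) * self.w_scale)
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc * w_scale)
    print(q_nope_out.shape)  # (num_heads, num_tokens, kv_lora_rank), transposed back afterwards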
python/sglang/srt/utils.py

@@ -2900,18 +2900,6 @@ def mxfp_supported():
         return False
 
 
-@lru_cache(maxsize=1)
-def is_gfx95_supported():
-    """
-    Returns whether the current platform supports MX types.
-    """
-    if torch.version.hip:
-        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        return any(gfx in gcn_arch for gfx in ["gfx95"])
-    else:
-        return False
-
-
 # LoRA-related constants and utilities
 SUPPORTED_LORA_TARGET_MODULES = [
     "q_proj",