change / sglang · Commits

Commit 613b197e (unverified)
Authored Apr 20, 2025 by fzyzcjy; committed via GitHub Apr 19, 2025
Parent: d58e3544

Remove one kernel in per_tensor_quant_mla_fp8 (#5549)
Changes: 4 changed files with 62 additions and 18 deletions (+62 −18)

  python/sglang/srt/layers/quantization/fp8_kernel.py   +9   −7
  python/sglang/srt/models/deepseek_nextn.py            +8   −2
  python/sglang/srt/models/deepseek_v2.py               +32  −9
  python/sglang/srt/utils.py                            +13  −0
python/sglang/srt/layers/quantization/fp8_kernel.py

@@ -58,10 +58,8 @@ if _is_cuda:
     ):
         _enable_jit_deepgemm = True

 logger = logging.getLogger(__name__)

 if supports_custom_op():

     def deep_gemm_fp8_fp8_bf16_nt(
...
@@ -897,16 +895,20 @@ def _per_tensor_quant_mla_fp8_stage2(


 def per_tensor_quant_mla_fp8(
-    x: torch.Tensor, eps: float = 1e-12
+    x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with tensor-wise quantization
     and specialized for mla absorbed case.
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"
+    assert (
+        x_s_out.shape == (1,)
+        and x_s_out.dtype == torch.float32
+        and x_s_out.device == x.device
+    )

     x_q = x.new_empty(x.size(), dtype=_fp8_type)
-    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)

     num_head, num_seq, head_size = x.shape
     BLOCK_SIZE = triton.next_power_of_2(head_size)
...
@@ -914,7 +916,7 @@ def per_tensor_quant_mla_fp8(
     _per_tensor_quant_mla_fp8_stage1[grid](
         x,
-        x_s,
+        x_s_out,
         head_size,
         x.stride(0),
         x.stride(1),
...
@@ -924,7 +926,7 @@ def per_tensor_quant_mla_fp8(
     )
     _per_tensor_quant_mla_fp8_stage2[grid](
         x,
-        x_s,
+        x_s_out,
         x_q,
         num_seq,
         head_size,
...
@@ -935,7 +937,7 @@ def per_tensor_quant_mla_fp8(
         BLOCK_SIZE,
     )

-    return x_q, x_s
+    return x_q, x_s_out


 def scaled_fp8_quant(
...
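The net effect of these hunks is that per_tensor_quant_mla_fp8 no longer launches a torch.zeros kernel on every call; the caller now supplies the one-element float32 scale buffer (x_s_out), which the two Triton stages write into and which is returned in place of the old x_s. A minimal sketch of the new calling convention, assuming a CUDA device and an fp8-capable sglang build with this commit applied; the tensor shape is illustrative only:

```python
import torch

from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
from sglang.srt.utils import BumpAllocator

# (num_head, num_seq, head_size), matching the 3-D layout asserted by the kernel.
x = torch.randn(16, 8, 512, dtype=torch.bfloat16, device="cuda")

# The caller pre-zeroes scale slots once up front, instead of the kernel
# calling torch.zeros((1,), ...) on every invocation.
zero_allocator = BumpAllocator(buffer_size=1, dtype=torch.float32, device=x.device)

x_q, x_s = per_tensor_quant_mla_fp8(x, zero_allocator.allocate(1))
# x_q: fp8-quantized copy of x; x_s: the same one-element buffer, now holding the scale.
```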
python/sglang/srt/models/deepseek_nextn.py

@@ -40,7 +40,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
...
@@ -91,6 +91,12 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            buffer_size=2,
+            dtype=torch.float32,
+            device=input_ids.device,
+        )
+
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
...
@@ -108,7 +114,7 @@ class DeepseekModelNextN(nn.Module):
         residual = None
         hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual
+            positions, hidden_states, forward_batch, residual, zero_allocator
         )

         if not forward_batch.forward_mode.is_idle():
...
python/sglang/srt/models/deepseek_v2.py

@@ -76,7 +76,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
...
@@ -97,7 +97,6 @@ logger = logging.getLogger(__name__)
 class AttnForwardMethod(IntEnum):
-
     # Use multi-head attention
     MHA = auto()
...
@@ -588,6 +587,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
             assert (
...
@@ -613,9 +613,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                     positions, hidden_states, forward_batch
                 )
             else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
+                return self.forward_absorb(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
         else:
-            return self.forward_absorb(positions, hidden_states, forward_batch)
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )

     def forward_normal(
         self,
...
@@ -664,6 +668,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
...
@@ -688,6 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
...
@@ -719,6 +725,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
...
@@ -739,6 +746,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
...
@@ -765,7 +773,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
...
@@ -861,7 +871,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
...
@@ -1113,14 +1125,15 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
             return self.forward_ffn_with_scattered_input(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         elif self.info.ffn_input_mode == _FFNInputMode.FULL:
             return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         else:
             raise NotImplementedError
...
@@ -1131,6 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
...
@@ -1151,6 +1165,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
+                zero_allocator=zero_allocator,
             )

         # Gather
...
@@ -1198,6 +1213,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
...
@@ -1223,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
+                zero_allocator=zero_allocator,
             )

         if self.attn_tp_size != 1:
...
@@ -1310,6 +1327,12 @@ class DeepseekV2Model(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            # TODO for two-batch-overlap, we need a larger buffer size
+            buffer_size=len(self.layers) * 2,
+            dtype=torch.float32,
+            device=input_ids.device,
+        )
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
...
@@ -1321,7 +1344,7 @@ class DeepseekV2Model(nn.Module):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
...
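Taken together, these hunks thread one zero_allocator from DeepseekV2Model.forward through each DeepseekV2DecoderLayer down to the MLA attention paths, where per_tensor_quant_mla_fp8 consumes at most two one-element scale slots per layer (one for q_nope, one for attn_output); that is where the model-level buffer_size = len(self.layers) * 2 comes from. A rough sketch of that flow, assuming sglang with this commit applied; the layer count, device, and loop body are illustrative stand-ins, not code from the diff:

```python
import torch

from sglang.srt.utils import BumpAllocator

num_layers = 4  # stands in for len(self.layers); the real count comes from the model config

# DeepseekV2Model.forward: one pre-zeroed buffer per forward pass,
# two scale slots reserved for every decoder layer.
zero_allocator = BumpAllocator(
    buffer_size=num_layers * 2,
    dtype=torch.float32,
    device="cuda",
)

for _ in range(num_layers):
    # DeepseekV2DecoderLayer.forward -> DeepseekV2AttentionMLA.forward_absorb:
    # each fp8 quant call takes one slot instead of launching torch.zeros itself.
    q_nope_scale_out = zero_allocator.allocate(1)
    attn_output_scale_out = zero_allocator.allocate(1)
```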
python/sglang/srt/utils.py

@@ -1932,3 +1932,16 @@ def is_fa3_default_architecture(hf_config):
         "MistralForCausalLM",
     }
     return architectures[0] in default_archs
+
+
+# Can be more general if it is used in multiple places (keep it simple and thus not general now)
+class BumpAllocator:
+    def __init__(self, buffer_size: int, dtype, device):
+        self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
+        self._pointer = 0
+
+    def allocate(self, size: int):
+        assert self._pointer + size <= len(self._buffer)
+        output = self._buffer[self._pointer : self._pointer + size]
+        self._pointer += size
+        return output
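BumpAllocator is a deliberately minimal bump-pointer allocator over a single pre-zeroed tensor: allocate(size) hands out the next slice as a view (no new kernel launch, no copy) and never frees, and a fresh instance is created per forward pass so the pointer never needs resetting. A quick standalone usage sketch (CPU and the peek at the internal buffer are purely for illustration):

```python
import torch

from sglang.srt.utils import BumpAllocator  # class added by this commit

allocator = BumpAllocator(buffer_size=4, dtype=torch.float32, device="cpu")

a = allocator.allocate(1)  # view of element 0, already zeroed
b = allocator.allocate(2)  # view of elements 1-2
assert a.shape == (1,) and b.shape == (2,)
assert torch.all(a == 0) and torch.all(b == 0)

# Writes through a returned view land in the shared buffer, which is how the
# Triton quant kernels publish the computed fp8 scale without allocating.
a[0] = 0.125
assert allocator._buffer[0].item() == 0.125  # peeking at the private buffer only to illustrate the view semantics
```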