sglang · Commit 613b197e

Remove one kernel in per_tensor_quant_mla_fp8 (#5549)

Unverified commit 613b197e, authored Apr 20, 2025 by fzyzcjy, committed via GitHub Apr 19, 2025.
Parent: d58e3544

Showing 4 changed files with 62 additions and 18 deletions (+62, -18).
Changed files:

python/sglang/srt/layers/quantization/fp8_kernel.py   (+9, -7)
python/sglang/srt/models/deepseek_nextn.py            (+8, -2)
python/sglang/srt/models/deepseek_v2.py               (+32, -9)
python/sglang/srt/utils.py                            (+13, -0)
python/sglang/srt/layers/quantization/fp8_kernel.py

@@ -58,10 +58,8 @@ if _is_cuda:
     ):
         _enable_jit_deepgemm = True

 logger = logging.getLogger(__name__)

 if supports_custom_op():

     def deep_gemm_fp8_fp8_bf16_nt(

@@ -897,16 +895,20 @@ def _per_tensor_quant_mla_fp8_stage2(
 def per_tensor_quant_mla_fp8(
-    x: torch.Tensor, eps: float = 1e-12
+    x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with tensor-wise quantization
     and specialized for mla absorbed case.
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"
+    assert (
+        x_s_out.shape == (1,)
+        and x_s_out.dtype == torch.float32
+        and x_s_out.device == x.device
+    )

     x_q = x.new_empty(x.size(), dtype=_fp8_type)
-    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)

     num_head, num_seq, head_size = x.shape
     BLOCK_SIZE = triton.next_power_of_2(head_size)

@@ -914,7 +916,7 @@ def per_tensor_quant_mla_fp8(
     _per_tensor_quant_mla_fp8_stage1[grid](
         x,
-        x_s,
+        x_s_out,
         head_size,
         x.stride(0),
         x.stride(1),

@@ -924,7 +926,7 @@ def per_tensor_quant_mla_fp8(
     )
     _per_tensor_quant_mla_fp8_stage2[grid](
         x,
-        x_s,
+        x_s_out,
         x_q,
         num_seq,
         head_size,

@@ -935,7 +937,7 @@ def per_tensor_quant_mla_fp8(
         BLOCK_SIZE,
     )
-    return x_q, x_s
+    return x_q, x_s_out


 def scaled_fp8_quant(
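In effect, the single-element scale tensor that `per_tensor_quant_mla_fp8` used to create with `torch.zeros` on every call (one extra kernel launch) is now supplied by the caller. A minimal sketch of the new calling convention, assuming a CUDA device with the sglang Triton kernels available; the shapes are illustrative only:

import torch

from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8

# Illustrative (num_head, num_seq, head_size) input for the absorbed-MLA case.
x = torch.randn(16, 8, 128, device="cuda", dtype=torch.bfloat16)

# The caller now owns the zero-initialized scale buffer; before this commit the
# function allocated it internally with torch.zeros on every call.
x_s_out = torch.zeros((1,), dtype=torch.float32, device=x.device)

x_q, x_s = per_tensor_quant_mla_fp8(x, x_s_out)  # x_s is the same tensor as x_s_out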
python/sglang/srt/models/deepseek_nextn.py

@@ -40,7 +40,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()

@@ -91,6 +91,12 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            buffer_size=2,
+            dtype=torch.float32,
+            device=input_ids.device,
+        )
+
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:

@@ -108,7 +114,7 @@ class DeepseekModelNextN(nn.Module):
         residual = None
         hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual
+            positions, hidden_states, forward_batch, residual, zero_allocator
         )

         if not forward_batch.forward_mode.is_idle():
python/sglang/srt/models/deepseek_v2.py

@@ -76,7 +76,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()

@@ -97,7 +97,6 @@ logger = logging.getLogger(__name__)
 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
     MHA = auto()

@@ -588,6 +587,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
             assert (

@@ -613,9 +613,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                     positions, hidden_states, forward_batch
                 )
             else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
+                return self.forward_absorb(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
         else:
-            return self.forward_absorb(positions, hidden_states, forward_batch)
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )

     def forward_normal(
         self,

@@ -664,6 +668,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(

@@ -688,6 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16

@@ -719,6 +725,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,

@@ -739,6 +746,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"

@@ -765,7 +773,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16

@@ -861,7 +871,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,

@@ -1113,14 +1125,15 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
             return self.forward_ffn_with_scattered_input(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         elif self.info.ffn_input_mode == _FFNInputMode.FULL:
             return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         else:
             raise NotImplementedError

@@ -1131,6 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:

@@ -1151,6 +1165,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )

         # Gather

@@ -1198,6 +1213,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:

@@ -1223,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )

         if self.attn_tp_size != 1:

@@ -1310,6 +1327,12 @@ class DeepseekV2Model(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            # TODO for two-batch-overlap, we need a larger buffer size
+            buffer_size=len(self.layers) * 2,
+            dtype=torch.float32,
+            device=input_ids.device,
+        )
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)

@@ -1321,7 +1344,7 @@ class DeepseekV2Model(nn.Module):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
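The allocator sizing above follows a simple budget: each decoder layer's MLA attention performs at most two fp8 per-tensor quantizations per forward pass (one for `q_nope`, one for `attn_output`), and each consumes one slot via `zero_allocator.allocate(1)`, hence `buffer_size=len(self.layers) * 2` here and `buffer_size=2` for the single NextN layer. A rough sketch of that accounting (names from this diff; `num_layers` and the device are illustrative), using the `BumpAllocator` added in python/sglang/srt/utils.py below:

import torch

from sglang.srt.utils import BumpAllocator  # added by this commit

num_layers = 61   # illustrative; DeepseekV2Model uses len(self.layers)
device = "cpu"    # illustrative; the model passes input_ids.device

zero_allocator = BumpAllocator(
    buffer_size=num_layers * 2, dtype=torch.float32, device=device
)

for _ in range(num_layers):
    q_nope_scale_out = zero_allocator.allocate(1)       # fp8 quant of q_nope
    attn_output_scale_out = zero_allocator.allocate(1)  # fp8 quant of attn_output

# The assert inside BumpAllocator.allocate guarantees this budget is never exceeded.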
python/sglang/srt/utils.py

@@ -1932,3 +1932,16 @@ def is_fa3_default_architecture(hf_config):
         "MistralForCausalLM",
     }
     return architectures[0] in default_archs
+
+
+# Can be more general if it is used in multiple places (keep it simple and thus not general now)
+class BumpAllocator:
+    def __init__(self, buffer_size: int, dtype, device):
+        self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
+        self._pointer = 0
+
+    def allocate(self, size: int):
+        assert self._pointer + size <= len(self._buffer)
+        output = self._buffer[self._pointer : self._pointer + size]
+        self._pointer += size
+        return output
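For completeness, a short usage sketch of the allocator (the device is chosen only for illustration): the buffer is zero-filled once at construction, and `allocate` returns consecutive views into it, which is exactly what `per_tensor_quant_mla_fp8` now expects as its pre-zeroed `x_s_out` argument.

import torch

from sglang.srt.utils import BumpAllocator

allocator = BumpAllocator(buffer_size=4, dtype=torch.float32, device="cpu")

a = allocator.allocate(1)
b = allocator.allocate(1)

assert a.shape == (1,) and float(a) == 0.0  # slices come out of the zero-filled buffer
assert a.data_ptr() != b.data_ptr()         # the pointer advances; no new allocation or kernel launch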