sglang · Commit 613b197e (unverified)
Authored Apr 20, 2025 by fzyzcjy; committed by GitHub Apr 19, 2025

Remove one kernel in per_tensor_quant_mla_fp8 (#5549)

Before this change, per_tensor_quant_mla_fp8 allocated its (1,)-shaped float32 scale tensor with torch.zeros on every call, launching an extra fill kernel each time. The function now takes a preallocated, zero-initialized x_s_out buffer from its caller; each model forward pass creates one BumpAllocator up front and hands out slices of it, so the per-call allocation kernel is removed.
Parent: d58e3544

Showing 4 changed files with 62 additions and 18 deletions (+62, -18)
python/sglang/srt/layers/quantization/fp8_kernel.py   +9   -7
python/sglang/srt/models/deepseek_nextn.py            +8   -2
python/sglang/srt/models/deepseek_v2.py               +32  -9
python/sglang/srt/utils.py                            +13  -0
python/sglang/srt/layers/quantization/fp8_kernel.py

@@ -58,10 +58,8 @@ if _is_cuda:
     ):
         _enable_jit_deepgemm = True

 logger = logging.getLogger(__name__)

 if supports_custom_op():

     def deep_gemm_fp8_fp8_bf16_nt(
@@ -897,16 +895,20 @@ def _per_tensor_quant_mla_fp8_stage2(

 def per_tensor_quant_mla_fp8(
-    x: torch.Tensor, eps: float = 1e-12
+    x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with tensor-wise quantization
     and specialized for mla absorbed case.
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"
+    assert (
+        x_s_out.shape == (1,)
+        and x_s_out.dtype == torch.float32
+        and x_s_out.device == x.device
+    )

     x_q = x.new_empty(x.size(), dtype=_fp8_type)
-    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)

     num_head, num_seq, head_size = x.shape
     BLOCK_SIZE = triton.next_power_of_2(head_size)
@@ -914,7 +916,7 @@ def per_tensor_quant_mla_fp8(
     _per_tensor_quant_mla_fp8_stage1[grid](
         x,
-        x_s,
+        x_s_out,
         head_size,
         x.stride(0),
         x.stride(1),
@@ -924,7 +926,7 @@ def per_tensor_quant_mla_fp8(
     )
     _per_tensor_quant_mla_fp8_stage2[grid](
         x,
-        x_s,
+        x_s_out,
         x_q,
         num_seq,
         head_size,
@@ -935,7 +937,7 @@ def per_tensor_quant_mla_fp8(
         BLOCK_SIZE,
     )
-    return x_q, x_s
+    return x_q, x_s_out


 def scaled_fp8_quant(
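To make the new contract concrete, here is a minimal, hypothetical calling sketch (assuming a CUDA build, where _fp8_type is torch.float8_e4m3fn). The torch.zeros call stands in for the BumpAllocator slice the model code actually passes, as the diffs below show:

import torch

from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8

# Hypothetical shapes: (num_head, num_seq, head_size), bfloat16 input on CUDA.
x = torch.randn(16, 4, 128, dtype=torch.bfloat16, device="cuda")

# The caller now owns the (1,)-shaped float32 scale buffer. Judging by the
# previous torch.zeros and the name zero_allocator, it must arrive
# zero-initialized; torch.zeros here plays the role of zero_allocator.allocate(1).
x_s_out = torch.zeros((1,), dtype=torch.float32, device=x.device)

x_q, x_s = per_tensor_quant_mla_fp8(x, x_s_out)
assert x_s is x_s_out  # the function returns the caller's buffer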
python/sglang/srt/models/deepseek_nextn.py

@@ -40,7 +40,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -91,6 +91,12 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            buffer_size=2,
+            dtype=torch.float32,
+            device=input_ids.device,
+        )
+
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
@@ -108,7 +114,7 @@ class DeepseekModelNextN(nn.Module):
         residual = None
         hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual
+            positions, hidden_states, forward_batch, residual, zero_allocator
         )

         if not forward_batch.forward_mode.is_idle():
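A hedged reading of buffer_size=2 here: the NextN model runs a single DeepseekV2DecoderLayer, and (per the deepseek_v2.py diff below) one attention pass performs at most two fp8 scale allocations of size 1 each, one for q_nope and one for attn_output. A sketch of that accounting:

import torch

from sglang.srt.utils import BumpAllocator

# Hypothetical accounting: one decoder layer, at most two allocate(1) calls.
zero_allocator = BumpAllocator(buffer_size=2, dtype=torch.float32, device="cpu")
q_nope_scale_buf = zero_allocator.allocate(1)       # first quantization in the layer
attn_output_scale_buf = zero_allocator.allocate(1)  # second quantization in the layer
# A third allocate(1) would trip the assert inside BumpAllocator.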
python/sglang/srt/models/deepseek_v2.py

@@ -76,7 +76,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -97,7 +97,6 @@ logger = logging.getLogger(__name__)

 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
     MHA = auto()
@@ -588,6 +587,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
             assert (
@@ -613,9 +613,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                     positions, hidden_states, forward_batch
                 )
             else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
+                return self.forward_absorb(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
         else:
-            return self.forward_absorb(positions, hidden_states, forward_batch)
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )

     def forward_normal(
         self,
@@ -664,6 +668,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -688,6 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -719,6 +725,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -739,6 +746,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -765,7 +773,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -861,7 +871,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -1113,14 +1125,15 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
             return self.forward_ffn_with_scattered_input(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         elif self.info.ffn_input_mode == _FFNInputMode.FULL:
             return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         else:
             raise NotImplementedError
@@ -1131,6 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
@@ -1151,6 +1165,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )

         # Gather
@@ -1198,6 +1213,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
@@ -1223,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )
         if self.attn_tp_size != 1:
@@ -1310,6 +1327,12 @@ class DeepseekV2Model(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            # TODO for two-batch-overlap, we need a larger buffer size
+            buffer_size=len(self.layers) * 2,
+            dtype=torch.float32,
+            device=input_ids.device,
+        )
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
@@ -1321,7 +1344,7 @@ class DeepseekV2Model(nn.Module):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
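The overall pattern, sketched with hypothetical stand-in classes (the real signatures are in the diff above): the model constructs one BumpAllocator per forward pass and threads it through every decoder layer down to the attention module, so each fp8 quantization reuses a pre-zeroed slice instead of launching its own torch.zeros:

import torch

from sglang.srt.utils import BumpAllocator


class ToyLayer:
    # Stand-in for DeepseekV2DecoderLayer: consumes two scale slots per pass.
    def forward(self, hidden_states, zero_allocator: BumpAllocator):
        q_scale = zero_allocator.allocate(1)    # plays the role of x_s_out for q_nope
        out_scale = zero_allocator.allocate(1)  # plays the role of x_s_out for attn_output
        return hidden_states


class ToyModel:
    # Stand-in for DeepseekV2Model: owns the per-forward allocator.
    def __init__(self, num_layers: int):
        self.layers = [ToyLayer() for _ in range(num_layers)]

    def forward(self, hidden_states):
        zero_allocator = BumpAllocator(
            buffer_size=len(self.layers) * 2,  # two allocations per layer, as in the diff
            dtype=torch.float32,
            device=hidden_states.device,
        )
        for layer in self.layers:
            hidden_states = layer.forward(hidden_states, zero_allocator)
        return hidden_states


out = ToyModel(num_layers=3).forward(torch.randn(4, 8))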
python/sglang/srt/utils.py

@@ -1932,3 +1932,16 @@ def is_fa3_default_architecture(hf_config):
         "MistralForCausalLM",
     }
     return architectures[0] in default_archs


+# Can be more general if it is used in multiple places (keep it simple and thus not general now)
+class BumpAllocator:
+    def __init__(self, buffer_size: int, dtype, device):
+        self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
+        self._pointer = 0
+
+    def allocate(self, size: int):
+        assert self._pointer + size <= len(self._buffer)
+        output = self._buffer[self._pointer : self._pointer + size]
+        self._pointer += size
+        return output
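A quick usage sketch of the new helper: allocate() hands out zero-initialized views into one preallocated buffer, so no per-allocation fill kernel runs after construction:

import torch

from sglang.srt.utils import BumpAllocator

alloc = BumpAllocator(buffer_size=4, dtype=torch.float32, device="cpu")
a = alloc.allocate(1)
b = alloc.allocate(2)
assert a.shape == (1,) and b.shape == (2,)
assert float(a[0]) == 0.0   # slices come back zeroed (the buffer was torch.zeros)
a.fill_(3.0)                # writes land in the shared underlying buffer
assert float(alloc._buffer[0]) == 3.0
# alloc.allocate(2) would now fail the internal assert: only 1 of 4 slots remains.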