change / sglang / Commits

Commit f37b05b5, authored Nov 12, 2025 by linhai1
Parent: 59b01a00

    modify codes with performance issues.

Showing 5 changed files with 36 additions and 38 deletions (+36, -38):
    python/sglang/srt/layers/attention/dcu_mla_backend.py          +13  -24
    python/sglang/srt/layers/attention/flashattention_backend.py    +7  -10
    python/sglang/srt/mem_cache/memory_pool.py                       +9   -0
    python/sglang/srt/model_executor/model_runner.py                 +4   -2
    python/sglang/srt/models/deepseek_v2.py                          +3   -2
python/sglang/srt/layers/attention/dcu_mla_backend.py

@@ -89,6 +89,7 @@ class DCUMLABackend(AttentionBackend):
         self.q_data_type = model_runner.dtype
         self.device = model_runner.device
+        self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
         self.max_context_len = model_runner.model_config.context_len
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
@@ -388,26 +389,20 @@ class DCUMLABackend(AttentionBackend):
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
-        if self.data_type in (
-            getattr(torch, "float8_e4m3fn", None),
-            getattr(torch, "float8_e4m3fnuz", None),
-            getattr(torch, "float8_e5m2", None),
-            getattr(torch, "float8_e5m2fnuz", None),
-        ):
-            if k_cache_reshaped.dtype == torch.float8_e4m3fnuz or \
-               k_cache_reshaped.dtype == torch.float8_e4m3fn:
+        if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
+                              torch.float8_e5m2, torch.float8_e5m2fnuz):
+            if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
                 kv_cache_dtype = "fp8_e4m3"
-            elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz or \
-                 k_cache_reshaped.dtype == torch.float8_e5m2:
+            else:
                 kv_cache_dtype = "fp8_e5m2"
-            k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
+            k_scale = layer.k_scale if layer.k_scale is not None else self.k_scale
             o = self._call_fp8_decode(
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
                 forward_batch.seq_lens.to(torch.int32),
                 layer.scaling,
-                k_scale.to(torch.float32),
+                k_scale,
                 kv_cache_dtype=kv_cache_dtype,
             )
         else:
@@ -460,26 +455,20 @@ class DCUMLABackend(AttentionBackend):
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
-        if self.data_type in (
-            getattr(torch, "float8_e4m3fn", None),
-            getattr(torch, "float8_e4m3fnuz", None),
-            getattr(torch, "float8_e5m2", None),
-            getattr(torch, "float8_e5m2fnuz", None),
-        ):
-            if k_cache_reshaped.dtype == torch.float8_e4m3fnuz or \
-               k_cache_reshaped.dtype == torch.float8_e4m3fn:
+        if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
+                              torch.float8_e5m2, torch.float8_e5m2fnuz):
+            if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
                 kv_cache_dtype = "fp8_e4m3"
-            elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz or \
-                 k_cache_reshaped.dtype == torch.float8_e5m2:
+            else:
                 kv_cache_dtype = "fp8_e5m2"
-            k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
+            k_scale = layer.k_scale if layer.k_scale is not None else self.k_scale
             o = self._call_fp8_decode(
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
                 (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                 layer.scaling,
-                k_scale.to(torch.float32),
+                k_scale,
                 kv_cache_dtype=kv_cache_dtype,
             )
         else:
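Note on dcu_mla_backend.py: the fp8 decode path previously rebuilt a [1.0] scale tensor, resolved the fp8 dtypes through four getattr() lookups, and cast k_scale to float32 on every decode call; those costs now happen once in __init__ (the new self.k_scale) or not at all. A minimal sketch of the pattern, assuming a PyTorch build that exposes all four float8 dtypes; FakeBackend and decode_step are illustrative names, not the sglang classes:

    import torch

    class FakeBackend:
        """Illustrative only: hoist per-call constants out of the decode hot loop."""

        def __init__(self, device: str = "cpu"):
            # Allocated once; the removed code built an equivalent tensor on every call.
            self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=device)
            # Resolved once instead of four getattr() lookups per decode step.
            self.fp8_dtypes = (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
                               torch.float8_e5m2, torch.float8_e5m2fnuz)

        def decode_step(self, layer_k_scale, data_type):
            if data_type in self.fp8_dtypes:
                # Reuse the cached scale: no allocation and no extra .to(float32) copy here.
                return layer_k_scale if layer_k_scale is not None else self.k_scale
            return None

    backend = FakeBackend()
    print(backend.decode_step(None, torch.float8_e4m3fn))  # tensor([1.])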
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -695,7 +695,6 @@ class FlashAttentionBackend(AttentionBackend):
             # has corresponding quantization method so that layer.k_scale is not None,
             # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
             # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
-            data_dtype = q.dtype
             if (
                 self.kv_cache_dtype_str != "auto"
                 and layer.head_dim <= 256
@@ -705,7 +704,7 @@ class FlashAttentionBackend(AttentionBackend):
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
                 v_descale = layer.v_scale.expand(descale_shape)
-                q = q.to(self.kv_cache_dtype)
+                # q = q.to(self.kv_cache_dtype)
                 q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
                 k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         causal = True
@@ -830,8 +829,6 @@ class FlashAttentionBackend(AttentionBackend):
             and not forward_batch.forward_mode.is_target_verify()
             and not forward_batch.forward_mode.is_draft_extend()
         ):
-            k_descale = k_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
-            v_descale = v_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
             # Do multi-head attention with chunked prefix cache
             if forward_batch.attn_attend_prefix_cache:
                 assert not get_global_server_args().disable_chunked_prefix_cache
@@ -845,9 +842,9 @@ class FlashAttentionBackend(AttentionBackend):
                 assert forward_batch.mha_return_lse
                 output = flash_attn_varlen_func(
-                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
-                    k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
-                    v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                     max_seqlen_q=metadata.max_seq_len_q,
@@ -859,9 +856,9 @@ class FlashAttentionBackend(AttentionBackend):
                 )
             else:
                 output = flash_attn_varlen_func(
-                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
-                    k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
-                    v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=metadata.cu_seqlens_q,
                     max_seqlen_q=metadata.max_seq_len_q,
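Note on flashattention_backend.py: the removed lines cast q to the fp8 cache dtype and multiplied the full K/V views by their descale factors before casting back with .to(data_dtype) on every call; the new lines pass the views through with a zero-copy .view(q.dtype) reinterpretation instead. A small standalone sketch of the .to(dtype) vs .view(dtype) distinction this relies on; float16/bfloat16 stand in for the fp8 cache dtypes, which need a newer GPU and PyTorch build:

    import torch

    # .to(dtype) materializes a converted copy; .view(dtype) reinterprets the
    # same bytes without touching memory.
    k = torch.ones(8, 4, dtype=torch.float16)

    copied = k.to(torch.bfloat16)           # new storage, values converted elementwise
    reinterpreted = k.view(torch.bfloat16)  # same storage, bytes reinterpreted, no copy

    print(copied.data_ptr() == k.data_ptr())         # False: an extra allocation per call
    print(reinterpreted.data_ptr() == k.data_ptr())  # True: no copy in the hot path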
python/sglang/srt/mem_cache/memory_pool.py

@@ -1301,6 +1301,15 @@ class MLATokenToKVPool(KVCache):
             return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.kv_buffer[layer_id - self.start_layer]
+
+    def get_key_buffer_DeepSeekV2(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype and self.dtype not in (
+                torch.float8_e5m2, torch.float8_e4m3fn):
+            return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.kv_buffer[layer_id - self.start_layer], self.dtype
 
     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
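Note on memory_pool.py: get_key_buffer_DeepSeekV2 lets a caller fetch the raw per-layer buffer together with its logical dtype and postpone the dtype view until after slicing. A hedged usage sketch for the fp8 path; FakeMLAPool, the shapes, and the bfloat16 target are illustrative, and it assumes a PyTorch build with float8 dtypes:

    import torch

    class FakeMLAPool:
        """Illustrative stand-in: fp8 KV values stored as raw uint8 bytes."""

        def __init__(self, num_tokens: int, kv_dim: int):
            self.store_dtype = torch.uint8      # physical storage dtype
            self.dtype = torch.float8_e4m3fn    # logical dtype of the cache
            self.start_layer = 0
            self.kv_buffer = [torch.zeros(num_tokens, kv_dim, dtype=self.store_dtype)]

        def get_key_buffer_DeepSeekV2(self, layer_id: int):
            # fp8 path: return the raw buffer plus its logical dtype and let the
            # caller apply .view(dtype) only to the slice it actually needs.
            return self.kv_buffer[layer_id - self.start_layer], self.dtype

    pool = FakeMLAPool(num_tokens=16, kv_dim=8)
    buf, dtype = pool.get_key_buffer_DeepSeekV2(layer_id=0)

    chunk_indices = torch.tensor([3, 5, 7])
    # Only the selected rows are reinterpreted and upcast, not the whole pool buffer.
    chunk = buf[chunk_indices].contiguous().view(dtype).to(torch.bfloat16)
    print(chunk.shape, chunk.dtype)  # torch.Size([3, 8]) torch.bfloat16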
python/sglang/srt/model_executor/model_runner.py

@@ -1624,12 +1624,14 @@ class ModelRunner:
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
             if _is_hip:  # Using natively supported format
-                self.kv_cache_dtype = torch.float8_e5m2fnuz
+                # self.kv_cache_dtype = torch.float8_e5m2fnuz
+                self.kv_cache_dtype = torch.float8_e5m2
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
         elif self.server_args.kv_cache_dtype == "fp8_e4m3":
             if _is_hip:  # Using natively supported format
-                self.kv_cache_dtype = torch.float8_e4m3fnuz
+                # self.kv_cache_dtype = torch.float8_e4m3fnuz
+                self.kv_cache_dtype = torch.float8_e4m3fn
             else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
         elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
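Note on model_runner.py: on ROCm/HIP builds the fp8 KV-cache strings now resolve to the OCP float8 formats (float8_e5m2, float8_e4m3fn) instead of the fnuz variants, matching the non-HIP branch. A compact sketch of the resulting mapping; resolve_kv_cache_dtype is a hypothetical helper, not a function in model_runner, and the "auto" branch is assumed from the surrounding context:

    import torch

    def resolve_kv_cache_dtype(kv_cache_dtype_str: str, model_dtype: torch.dtype) -> torch.dtype:
        # Hypothetical helper mirroring the post-change behaviour: the same OCP
        # fp8 dtypes are selected whether or not the build targets HIP.
        if kv_cache_dtype_str == "auto":     # assumed: fall back to the model dtype
            return model_dtype
        if kv_cache_dtype_str == "fp8_e5m2":
            return torch.float8_e5m2
        if kv_cache_dtype_str == "fp8_e4m3":
            return torch.float8_e4m3fn
        if kv_cache_dtype_str in ("bf16", "bfloat16"):
            return torch.bfloat16
        raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype_str}")

    print(resolve_kv_cache_dtype("fp8_e5m2", torch.bfloat16))  # torch.float8_e5m2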
python/sglang/srt/models/deepseek_v2.py

@@ -2294,12 +2294,13 @@ class DeepseekV2AttentionMLA(nn.Module):
             forward_batch.set_prefix_chunk_idx(i)

             # Fetch latent cache from memory pool with precomputed chunked kv indices
-            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+            latent_cache_buf, dtype = forward_batch.token_to_kv_pool.get_key_buffer_DeepSeekV2(
                 self.attn_mha.layer_id
-            ).to(q.dtype)
+            )
             latent_cache = (
                 latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
                 .contiguous()
+                .view(dtype)
                 .to(q.dtype)
             )
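Note on deepseek_v2.py: the removed code converted the entire per-layer latent cache buffer to q.dtype before indexing a prefix chunk, so every chunk iteration paid for a full-buffer cast; the new code slices the raw buffer first and only views/upcasts the chunk. A toy comparison of the two orderings; the shapes and the uint8 pool layout are invented for illustration, and float8 support requires a recent PyTorch:

    import torch

    store = torch.zeros(50_000, 576, dtype=torch.uint8)  # raw fp8 bytes in a per-layer pool
    chunk_idx = torch.arange(4_096)
    q_dtype = torch.bfloat16

    # Before (removed path): upcast the whole buffer, then pick out the chunk.
    before = store.view(torch.float8_e4m3fn).to(q_dtype)[chunk_idx]

    # After (new path): slice the raw bytes, then view/upcast only the chunk.
    after = store[chunk_idx].contiguous().view(torch.float8_e4m3fn).to(q_dtype)

    assert torch.equal(before, after)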