change / sglang · Commit 484c5433
Authored Nov 11, 2025 by linhai1
Parent commit: 93eb92f8

Commit message: support fp8_e4m3.

Showing 4 changed files with 150 additions and 17 deletions (+150 / -17)
Changed files:
  python/sglang/srt/layers/attention/dcu_mla_backend.py         +118 / -9
  python/sglang/srt/layers/attention/flashattention_backend.py   +14 / -7
  python/sglang/srt/models/deepseek_v2.py                          +1 / -1
  python/sglang/srt/speculative/draft_utils.py                    +17 / -0
python/sglang/srt/layers/attention/dcu_mla_backend.py

@@ -2,7 +2,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

 import torch
 import triton
@@ -363,7 +363,7 @@ class DCUMLABackend(AttentionBackend):
     def _call_fp8_decode(
         self,
         reshape_q: torch.Tensor,
         k_cache_reshaped: torch.Tensor,
         block_table: torch.Tensor,
         cache_seqlens: torch.Tensor,
-        scaling: float):
+        scaling: float, k_scale=None, kv_cache_dtype=None):
         assert _has_flash_mla, "FP8 KV cache requires the flash_mla package"
         o, _ = flash_mla_with_kvcache_quantization(
             q=reshape_q,

@@ -375,7 +375,8 @@ class DCUMLABackend(AttentionBackend):
             num_splits=self.forward_metadata.num_splits,
             softmax_scale=scaling,
             causal=True,
             is_fp8_kvcache=True,
+            k_scale=k_scale,
+            kv_cache_dtype=kv_cache_dtype,
         )
         return o
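For context on what the new k_scale and kv_cache_dtype arguments carry: a per-tensor scale lets an fp8_e4m3 KV cache be dequantized back to the compute dtype inside the kernel. Below is a minimal sketch of that round trip in plain PyTorch, independent of flash_mla_with_kvcache_quantization; tensor names and shapes are illustrative only.

```python
import torch

# Illustrative only: a tiny "KV cache" in bf16 and its fp8_e4m3 quantized form.
# Requires a PyTorch build that exposes torch.float8_e4m3fn.
assert hasattr(torch, "float8_e4m3fn"), "this sketch needs fp8 dtype support"

kv = torch.randn(4, 1, 576, dtype=torch.bfloat16)        # [tokens, heads, head_dim]
k_scale = kv.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max

kv_fp8 = (kv.float() / k_scale).to(torch.float8_e4m3fn)  # what the cache stores
kv_deq = kv_fp8.float() * k_scale                        # what the kernel reconstructs

print("max abs error:", (kv_deq - kv.float()).abs().max().item())
```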
@@ -412,14 +413,23 @@ class DCUMLABackend(AttentionBackend):
             getattr(torch, "float8_e5m2", None),
             getattr(torch, "float8_e5m2fnuz", None),
         ):
-            o = self._call_fp8_decode(
-                reshape_q,
-                k_cache_reshaped,
-                self.forward_metadata.block_kv_indices[:bs],
-                forward_batch.seq_lens.to(torch.int32),
-                layer.scaling,
-            )
+            k_scale = (
+                layer.k_scale
+                if layer.k_scale is not None
+                else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
+            )
+            o = self._call_fp8_decode(
+                reshape_q,
+                k_cache_reshaped,
+                self.forward_metadata.block_kv_indices[:bs],
+                forward_batch.seq_lens.to(torch.int32),
+                layer.scaling,
+                k_scale.to(torch.float32),
+                kv_cache_dtype="fp8_e4m3",
+            )
         else:
             o = self._call_decode(
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
                 forward_batch.seq_lens.to(torch.int32),
                 layer.scaling,
             )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
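The dispatch above rests on two small idioms: the fp8 dtype tuple is built with getattr so that older PyTorch builds lacking a given fp8 dtype simply contribute None, and the scale falls back to 1.0 when the layer has no calibrated k_scale. A standalone sketch of that selection logic follows; the function and variable names are hypothetical and not part of the backend.

```python
import torch

# Hypothetical helper mirroring the dtype check in the decode path; names are illustrative.
_FP8_DTYPES = tuple(
    dt
    for dt in (
        getattr(torch, "float8_e4m3fn", None),
        getattr(torch, "float8_e4m3fnuz", None),
        getattr(torch, "float8_e5m2", None),
        getattr(torch, "float8_e5m2fnuz", None),
    )
    if dt is not None
)

def resolve_kv_path(kv_cache: torch.Tensor, k_scale=None):
    """Return ("fp8", scale) or ("default", None) depending on the cache dtype."""
    if kv_cache.dtype in _FP8_DTYPES:
        scale = (
            k_scale
            if k_scale is not None
            else torch.tensor([1.0], dtype=torch.float32, device=kv_cache.device)
        )
        return "fp8", scale.to(torch.float32)
    return "default", None

# Usage: a bf16 cache takes the default path, an fp8 cache takes the quantized path.
print(resolve_kv_path(torch.zeros(2, 2, dtype=torch.bfloat16)))
if _FP8_DTYPES:
    print(resolve_kv_path(torch.zeros(2, 2).to(_FP8_DTYPES[0])))
```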
@@ -474,14 +484,21 @@ class DCUMLABackend(AttentionBackend):
             getattr(torch, "float8_e5m2", None),
             getattr(torch, "float8_e5m2fnuz", None),
         ):
-            o = self._call_fp8_decode(
-                reshape_q,
-                k_cache_reshaped,
-                self.forward_metadata.block_kv_indices[:bs],
-                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
-                layer.scaling,
-            )
+            k_scale = (
+                layer.k_scale
+                if layer.k_scale is not None
+                else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
+            )
+            o = self._call_fp8_decode(
+                reshape_q,
+                k_cache_reshaped,
+                self.forward_metadata.block_kv_indices[:bs],
+                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+                layer.scaling,
+                k_scale.to(torch.float32),
+                kv_cache_dtype=self.data_type,
+            )
         else:
             o = self._call_decode(
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
                 (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                 layer.scaling,
             )
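In this target-verify path the only difference from plain decode is that each request's cache length is extended by the draft tokens being verified. A short illustration of that cache_seqlens computation, with made-up values:

```python
import torch

# Illustrative values: three requests with their committed KV lengths,
# plus 4 speculative draft tokens appended per request for verification.
seq_lens = torch.tensor([17, 5, 230])
num_draft_tokens = 4

cache_seqlens = (seq_lens + num_draft_tokens).to(torch.int32)
print(cache_seqlens)  # tensor([ 21,   9, 234], dtype=torch.int32)
```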
@@ -489,3 +506,95 @@ class DCUMLABackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+
+class DCUMLAMultiStepDraftBackend:
+    """
+    Wrap multiple flashmla attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        if topk > 1:
+            raise ValueError(
+                "Currently FlashMLA only supports topk=1 for speculative decoding"
+            )
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends.append(
+                DCUMLABackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                    kv_last_page_len_buf=None,
+                )
+            )
+
+    def common_template(
+        self,
+        forward_batch: ForwardBatch,
+        call_fn: Callable,
+    ):
+        assert forward_batch.spec_info is not None
+
+        for i in range(self.speculative_num_steps - 1):
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
+
+        self.common_template(forward_batch, call_fn)
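The new DCUMLAMultiStepDraftBackend follows the same wrapper pattern as the existing FlashMLA multi-step backend: keep one per-step attention backend and apply the same callback to each step. A self-contained sketch of that pattern with stand-in classes follows; nothing here is the real ModelRunner or backend API.

```python
from typing import Callable, List

class _StubStepBackend:
    """Stand-in for a per-step attention backend (not the real DCUMLABackend)."""
    def __init__(self, step: int):
        self.step = step

    def init_forward_metadata(self, batch) -> None:
        print(f"step {self.step}: init metadata for batch {batch!r}")

class _StubMultiStepDraftBackend:
    """Illustrates only the wrap-N-backends / common_template pattern."""
    def __init__(self, speculative_num_steps: int):
        # One backend per draft step except the last, matching the loop above.
        self.backends: List[_StubStepBackend] = [
            _StubStepBackend(i) for i in range(speculative_num_steps - 1)
        ]

    def common_template(self, batch, call_fn: Callable[[int, object], None]) -> None:
        for i in range(len(self.backends)):
            call_fn(i, batch)

    def init_forward_metadata(self, batch) -> None:
        self.common_template(
            batch, lambda i, b: self.backends[i].init_forward_metadata(b)
        )

_StubMultiStepDraftBackend(speculative_num_steps=3).init_forward_metadata("demo-batch")
```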
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -695,6 +695,7 @@ class FlashAttentionBackend(AttentionBackend):
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
         # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
+        data_dtype = q.dtype
         if (
             self.kv_cache_dtype_str != "auto"
             and layer.head_dim <= 256
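data_dtype defaults to q.dtype and, per the conditions listed in the surrounding comment, is only switched to an fp8 dtype when the KV cache is quantized, the layer carries a k_scale, head_dim is at most 256, and the FA implementation supports fp8. Below is a hedged sketch of that decision with a hypothetical helper name, using only the conditions quoted above rather than the backend's exact code.

```python
import torch
from typing import Optional

def choose_data_dtype(
    q_dtype: torch.dtype,
    kv_cache_dtype_str: str,
    k_scale: Optional[torch.Tensor],
    head_dim: int,
    fa_impl_ver: int,
) -> torch.dtype:
    """Hypothetical helper: mirrors the commented conditions, not the backend's exact logic."""
    fp8 = getattr(torch, "float8_e4m3fn", None)
    if (
        fp8 is not None
        and kv_cache_dtype_str != "auto"   # (1)/(2) quantized cache with a scale
        and k_scale is not None
        and head_dim <= 256                # (3) fa3 fp8 path limit
        and fa_impl_ver != 4               # (4) fa4 lacks fp8 q/k support
    ):
        return fp8
    return q_dtype                         # fall back to the query dtype

print(choose_data_dtype(torch.bfloat16, "fp8_e4m3", torch.tensor([0.02]), 128, 3))
print(choose_data_dtype(torch.bfloat16, "auto", None, 128, 3))
```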
@@ -828,7 +829,9 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.attn_attend_prefix_cache is not None
             and not forward_batch.forward_mode.is_target_verify()
             and not forward_batch.forward_mode.is_draft_extend()
         ):
+            k_descale = k_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
+            v_descale = v_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
             # Do multi-head attention with chunked prefix cache
             if forward_batch.attn_attend_prefix_cache:
                 assert not get_global_server_args().disable_chunked_prefix_cache
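The 1.0 fallback makes the later k * k_descale multiply a no-op while keeping broadcasting uniform: a single-element float32 tensor broadcasts across a [tokens, heads, head_dim] K without changing values. A two-line check, with illustrative shapes:

```python
import torch

k = torch.randn(8, 4, 64, dtype=torch.bfloat16)       # [tokens, kv_heads, head_dim]
k_descale = torch.tensor([1.0], dtype=torch.float32)  # fallback when layer.k_scale is None
assert torch.equal((k * k_descale).to(k.dtype), k)    # broadcasting no-op
```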
@@ -842,9 +845,9 @@ class FlashAttentionBackend(AttentionBackend):
                 assert forward_batch.mha_return_lse
                 output = flash_attn_varlen_func(
-                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
+                    k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
+                    v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                     max_seqlen_q=metadata.max_seq_len_q,
@@ -856,10 +859,14 @@ class FlashAttentionBackend(AttentionBackend):
                 )
             else:
                 # MHA for extend part of sequence without attending prefix kv cache
+                # if layer.layer_id == 0:
+                #     print("q.dtype, k.shape, v.shape, k.dtype, v.dtype, layer.k_scale.shape, layer.k_scale.dtype, layer.v_scale.shape, layer.v_scale.dtype, \n",
+                #         q.dtype, k.shape, v.shape, k.dtype, v.dtype, )
+                #     print("layer info: \n", layer)
                 output = flash_attn_varlen_func(
-                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
+                    k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
+                    v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=metadata.cu_seqlens_q,
                     max_seqlen_q=metadata.max_seq_len_q,
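A way to sanity-check the rewritten q/k/v arguments: descaling fp8 K/V and casting to data_dtype should give approximately the same attention output as running on the original unquantized tensors. The sketch below uses torch.nn.functional.scaled_dot_product_attention as a stand-in for flash_attn_varlen_func, with made-up shapes and a per-tensor scale; it only runs where fp8 dtypes are available.

```python
import torch
import torch.nn.functional as F

if hasattr(torch, "float8_e4m3fn"):
    torch.manual_seed(0)
    T, H, D = 16, 2, 64
    q = torch.randn(T, H, D, dtype=torch.bfloat16)
    k = torch.randn(T, H, D, dtype=torch.bfloat16)
    v = torch.randn(T, H, D, dtype=torch.bfloat16)

    # Per-tensor quantization of K/V to fp8_e4m3 (illustrative, not the pool's real layout).
    k_scale = k.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
    v_scale = v.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
    k_fp8 = (k.float() / k_scale).to(torch.float8_e4m3fn)
    v_fp8 = (v.float() / v_scale).to(torch.float8_e4m3fn)

    data_dtype = q.dtype
    # Mirror of the diff: multiply by the descale, then cast to the compute dtype.
    k_deq = (k_fp8.float() * k_scale).to(data_dtype)
    v_deq = (v_fp8.float() * v_scale).to(data_dtype)

    # SDPA expects [batch, heads, seq, dim]; the real kernel here is flash_attn_varlen_func.
    def attn(q_, k_, v_):
        return F.scaled_dot_product_attention(
            q_.transpose(0, 1)[None], k_.transpose(0, 1)[None], v_.transpose(0, 1)[None]
        )

    ref = attn(q, k, v)
    quant = attn(q, k_deq, v_deq)
    print("max abs diff:", (ref - quant).abs().max().item())
```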
python/sglang/srt/models/deepseek_v2.py

@@ -2296,7 +2296,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             # Fetch latent cache from memory pool with precomputed chunked kv indices
             latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                 self.attn_mha.layer_id
-            )
+            ).to(q.dtype)
             latent_cache = (
                 latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
                 .contiguous()
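This one-line change upcasts the fetched latent-cache buffer to the query dtype before it is indexed for chunked-prefix MHA, so downstream matmuls see a matching dtype even when the pool stores a lower-precision cache. A minimal illustration with a stand-in buffer; shapes and dtypes are made up:

```python
import torch

q = torch.randn(4, 16, dtype=torch.bfloat16)

# Stand-in for token_to_kv_pool.get_key_buffer(layer_id): a buffer stored in a
# lower-precision cache dtype (fp16 here; the real pool may hold fp8).
latent_cache_buf = torch.randn(32, 16, dtype=torch.float16).to(q.dtype)

prefix_chunk_kv_indices = torch.tensor([3, 7, 9])
latent_cache = latent_cache_buf[prefix_chunk_kv_indices].contiguous()

assert latent_cache.dtype == q.dtype
```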
python/sglang/srt/speculative/draft_utils.py

@@ -46,6 +46,7 @@ class DraftBackendFactory:
                 else self._create_triton_decode_backend
             ),
             "flashmla": self._create_flashmla_decode_backend,
+            "dcu_mla": self._create_dcumla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
             "nsa": self._create_nsa_decode_backend,

@@ -69,6 +70,7 @@ class DraftBackendFactory:
                 else self._create_triton_prefill_backend
             ),
             "flashmla": self._create_flashmla_prefill_backend,
+            "dcu_mla": self._create_dcumla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
             "nsa": self._create_nsa_prefill_backend,

@@ -149,6 +151,15 @@ class DraftBackendFactory:
         return FlashMLAMultiStepDraftBackend(
             self.draft_model_runner, self.topk, self.speculative_num_steps
         )

+    def _create_dcumla_decode_backend(self):
+        from sglang.srt.layers.attention.dcu_mla_backend import (
+            DCUMLAMultiStepDraftBackend,
+        )
+
+        return DCUMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
     def _create_trtllm_mha_decode_backend(self):
         from sglang.srt.layers.attention.trtllm_mha_backend import (

@@ -224,3 +235,9 @@ class DraftBackendFactory:
             "flashmla prefill backend is not yet supported for draft extend."
         )
         return None
+
+    def _create_dcumla_prefill_backend(self):
+        logger.warning(
+            "flashmla prefill backend is not yet supported for draft extend."
+        )
+        return None
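Registering "dcu_mla" in both dictionaries is all the wiring the draft path needs: the factory looks the configured attention backend name up in a name-to-constructor map and calls the match lazily. A standalone sketch of that dispatch with stub constructors, not the real factory:

```python
import logging

logger = logging.getLogger(__name__)

# Stub constructors standing in for DraftBackendFactory._create_*_decode_backend.
def _create_flashmla_decode_backend():
    return "FlashMLAMultiStepDraftBackend()"

def _create_dcumla_decode_backend():
    return "DCUMLAMultiStepDraftBackend()"

DECODE_BACKENDS = {
    "flashmla": _create_flashmla_decode_backend,
    "dcu_mla": _create_dcumla_decode_backend,
}

def create_decode_backend(name: str):
    try:
        return DECODE_BACKENDS[name]()
    except KeyError:
        logger.warning("unknown draft decode backend: %s", name)
        return None

print(create_decode_backend("dcu_mla"))
```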