change / sglang / Commits

Commit 50f7ea0f
authored Nov 11, 2025 by linhai1

support fp8_e4m3.

Parents: 484c5433, 6741925c
Showing 2 changed files with 19 additions and 40 deletions (+19 -40)

python/sglang/srt/layers/attention/dcu_mla_backend.py  (+19 -35)
python/sglang/srt/layers/attention/flashattention_backend.py  (+0 -5)
python/sglang/srt/layers/attention/dcu_mla_backend.py
@@ -47,6 +47,16 @@ class VllmMLADecodeMetadata:
     num_splits: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
 
+    def __init__(
+        self,
+        flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        num_splits: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.flashmla_metadata = flashmla_metadata
+        self.num_splits = num_splits
+        self.block_kv_indices = block_kv_indices
+
 
 class DCUMLABackend(AttentionBackend):
     def __init__(
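VllmMLADecodeMetadata now has an explicit constructor with every field optional. A minimal sketch of the call pattern, using a stand-in class that mirrors the added __init__ so it can be run outside sglang; the class name and the tensor shapes are illustrative assumptions, not part of the diff:

    # Illustrative only: a stand-in mirroring the __init__ added above.
    from typing import Optional, Tuple

    import torch


    class MLADecodeMetadataSketch:
        def __init__(
            self,
            flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            num_splits: Optional[torch.Tensor] = None,
            block_kv_indices: Optional[torch.Tensor] = None,
        ):
            self.flashmla_metadata = flashmla_metadata
            self.num_splits = num_splits
            self.block_kv_indices = block_kv_indices


    # Every field defaults to None, so partially populated metadata is valid.
    meta = MLADecodeMetadataSketch(
        num_splits=torch.zeros(4, dtype=torch.int32),            # shape is an assumption
        block_kv_indices=torch.zeros(4, 64, dtype=torch.int32),  # shape is an assumption
    )
    print(meta.flashmla_metadata)  # None until the flashmla metadata is computed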
@@ -86,14 +96,6 @@ class DCUMLABackend(AttentionBackend):
         self.skip_prefill = skip_prefill
 
         if not skip_prefill:
-            # Use the triton backend for now; consider replacing it later
-            # from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-            # self.triton_backend = TritonAttnBackend(
-            #     model_runner,
-            #     skip_prefill=False,
-            #     kv_indptr_buf=kv_indptr_buf,
-            # )
-            # Switch prefill to flash attn
             from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
             self.flashattn_backend = FlashAttentionBackend(
                 model_runner,
@@ -147,9 +149,7 @@ class DCUMLABackend(AttentionBackend):
                 mla_metadata, num_splits_t, block_kv_indices
             )
         else:
-            # prefill/extend used the triton backend -> switched to flash attn
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata(forward_batch)
                 self.flashattn_backend.init_forward_metadata(forward_batch)
 
     def init_cuda_graph_state(
@@ -241,15 +241,6 @@ class DCUMLABackend(AttentionBackend):
             )
         else:
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata_capture_cuda_graph(
-                #     bs,
-                #     num_tokens,
-                #     req_pool_indices,
-                #     seq_lens,
-                #     encoder_lens,
-                #     forward_mode,
-                #     spec_info,
-                # )
                 self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
                     bs,
                     num_tokens,
@@ -321,16 +312,6 @@ class DCUMLABackend(AttentionBackend):
             ]
         else:
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata_replay_cuda_graph(
-                #     bs,
-                #     req_pool_indices,
-                #     seq_lens,
-                #     seq_lens_sum,
-                #     encoder_lens,
-                #     forward_mode,
-                #     spec_info,
-                #     seq_lens_cpu,
-                # )
                 self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
                     bs,
                     req_pool_indices,
@@ -413,6 +394,10 @@ class DCUMLABackend(AttentionBackend):
                 getattr(torch, "float8_e5m2", None),
                 getattr(torch, "float8_e5m2fnuz", None),
             ):
+                if k_cache_reshaped.dtype == torch.float8_e4m3fnuz:
+                    kv_cache_dtype = "fp8_e4m3"
+                elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz:
+                    kv_cache_dtype = "fp8_e5m2"
                 k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
                 o = self._call_fp8_decode(
                     reshape_q,
@@ -421,7 +406,7 @@ class DCUMLABackend(AttentionBackend):
                     forward_batch.seq_lens.to(torch.int32),
                     layer.scaling,
                     k_scale.to(torch.float32),
-                    kv_cache_dtype="fp8_e4m3",
+                    kv_cache_dtype=kv_cache_dtype,
                 )
             else:
                 o = self._call_decode(
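The two hunks above make the fp8 decode path derive the kv_cache_dtype string from the KV cache's actual dtype instead of hard-coding "fp8_e4m3". A rough standalone sketch of that mapping; the helper name, the default fallback, and the tensor shape are assumptions outside the diff, and it only runs on a PyTorch build that exposes the ROCm fnuz fp8 dtypes:

    import torch


    def kv_cache_dtype_name(dtype: torch.dtype, default: str = "fp8_e4m3") -> str:
        # Hypothetical helper (not in the diff): map the ROCm "fnuz" fp8 storage
        # dtypes to the string names the fp8 decode kernel expects, as the added
        # branch above does inline.
        if dtype == getattr(torch, "float8_e4m3fnuz", None):
            return "fp8_e4m3"
        if dtype == getattr(torch, "float8_e5m2fnuz", None):
            return "fp8_e5m2"
        return default


    if hasattr(torch, "float8_e4m3fnuz"):
        k_cache = torch.empty(16, 1, 576, dtype=torch.float8_e4m3fnuz)  # shape is an assumption
        print(kv_cache_dtype_name(k_cache.dtype))  # -> "fp8_e4m3"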
@@ -442,7 +427,6 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
@@ -455,11 +439,7 @@ class DCUMLABackend(AttentionBackend):
             forward_batch.forward_mode == ForwardMode.EXTEND
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
         ):
-            # flash_attn does not support fp8, so extend cannot run correctly in fp8
             if not self.skip_prefill:
-                # return self.triton_backend.forward_extend(
-                #     q, k, v, layer, forward_batch, save_kv_cache, sinks
-                # )
                 return self.flashattn_backend.forward_extend(
                     q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
                 )
@@ -484,6 +464,10 @@ class DCUMLABackend(AttentionBackend):
                 getattr(torch, "float8_e5m2", None),
                 getattr(torch, "float8_e5m2fnuz", None),
             ):
+                if k_cache_reshaped.dtype == torch.float8_e4m3fnuz:
+                    k_cache_reshaped = k_cache_reshaped.view(torch.float8_e4m3fn)
+                elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz:
+                    k_cache_reshaped = k_cache_reshaped.view(torch.float8_e5m2)
                 k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
                 o = self._call_fp8_decode(
                     reshape_q,
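In this second fp8 branch the cache is reinterpreted rather than converted: Tensor.view(dtype) keeps the same 1-byte storage and only relabels it from the ROCm fnuz formats to the OCP fp8 formats. A small sketch of that relabeling, guarded so it only runs on PyTorch builds that expose both dtypes; the shape is an assumption, and note that e4m3fnuz and e4m3fn assign different meanings to the same bit patterns, so numeric correctness depends on how the cache was written and what the kernel expects:

    import torch

    # Sketch: relabel the dtype of a 1-byte fp8 tensor without copying, mirroring
    # the k_cache_reshaped.view(...) calls added above.
    if hasattr(torch, "float8_e4m3fnuz") and hasattr(torch, "float8_e4m3fn"):
        k_cache = torch.empty(16, 8, dtype=torch.float8_e4m3fnuz)
        relabeled = k_cache.view(torch.float8_e4m3fn)
        assert relabeled.data_ptr() == k_cache.data_ptr()  # no copy, shared storage
        print(k_cache.dtype, "->", relabeled.dtype)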
python/sglang/srt/layers/attention/flashattention_backend.py
@@ -858,11 +858,6 @@ class FlashAttentionBackend(AttentionBackend):
                     **kwargs,
                 )
             else:
-                # MHA for extend part of sequence without attending prefix kv cache
-                # if layer.layer_id == 0:
-                #     print("q.dtype, k.shape, v.shape, k.dtype, v.dtype, layer.k_scale.shape, layer.k_scale.dtype, layer.v_scale.shape, layer.v_scale.dtype, \n",
-                #           q.dtype, k.shape, v.shape, k.dtype, v.dtype, )
-                # print("layer info: \n", layer)
                 output = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
                     k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
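For context on the surviving call above: the fp8 K (and V) tensors are dequantized on the fly, multiplied by their descale factor and cast to the activation dtype before the varlen attention call. A rough standalone sketch of that quantize/dequantize round trip using plain torch ops; the shapes, the per-tensor scaling scheme, and the 448 e4m3fn max are assumptions outside the diff, and the dense SDPA call is only a placeholder for flash_attn_varlen_func (requires a PyTorch build with the OCP fp8 dtypes):

    import torch
    import torch.nn.functional as F

    # Toy shapes (tokens, heads, head_dim); the real values come from the layer.
    tokens, heads, head_dim = 8, 4, 64
    data_dtype = torch.float32  # bf16/fp16 in practice

    q = torch.randn(tokens, heads, head_dim, dtype=data_dtype)

    # Quantize a stand-in K to fp8 with a per-tensor scale, then dequantize the
    # way the surviving branch does: multiply by the descale, cast to data_dtype.
    k_ref = torch.randn(tokens, heads, head_dim)
    k_scale = k_ref.abs().max() / 448.0              # 448 = max normal value of e4m3fn
    k_fp8 = (k_ref / k_scale).to(torch.float8_e4m3fn)
    k_descale = k_scale

    k = (k_fp8.to(torch.float32) * k_descale).to(data_dtype)
    v = k.clone()  # stand-in V

    # Dense single-sequence attention as a placeholder for flash_attn_varlen_func.
    out = F.scaled_dot_product_attention(
        q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1)
    )
    print(out.shape)  # (heads, tokens, head_dim)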