Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
59b01a00
"doc/git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "8384badc69c219c51918c70ac5e7eb8528253b3e"
Commit
59b01a00
authored
Nov 11, 2025
by
linhai1
Browse files
support fp8_e4m3 and fp8_e5m2.
parent
34f0ebb1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
2 deletions
+4
-2
python/sglang/srt/layers/attention/dcu_mla_backend.py
python/sglang/srt/layers/attention/dcu_mla_backend.py
+4
-2
No files found.
python/sglang/srt/layers/attention/dcu_mla_backend.py
View file @
59b01a00
...
@@ -394,9 +394,11 @@ class DCUMLABackend(AttentionBackend):
...
@@ -394,9 +394,11 @@ class DCUMLABackend(AttentionBackend):
getattr
(
torch
,
"float8_e5m2"
,
None
),
getattr
(
torch
,
"float8_e5m2"
,
None
),
getattr
(
torch
,
"float8_e5m2fnuz"
,
None
),
getattr
(
torch
,
"float8_e5m2fnuz"
,
None
),
):
):
if
k_cache_reshaped
.
dtype
==
torch
.
float8_e4m3fnuz
:
if
k_cache_reshaped
.
dtype
==
torch
.
float8_e4m3fnuz
or
\
k_cache_reshaped
.
dtype
==
torch
.
float8_e4m3fn
:
kv_cache_dtype
=
"fp8_e4m3"
kv_cache_dtype
=
"fp8_e4m3"
elif
k_cache_reshaped
.
dtype
==
torch
.
float8_e5m2fnuz
:
elif
k_cache_reshaped
.
dtype
==
torch
.
float8_e5m2fnuz
or
\
k_cache_reshaped
.
dtype
==
torch
.
float8_e5m2
:
kv_cache_dtype
=
"fp8_e5m2"
kv_cache_dtype
=
"fp8_e5m2"
k_scale
=
layer
.
k_scale
if
layer
.
k_scale
is
not
None
else
torch
.
tensor
([
1.0
],
dtype
=
torch
.
float32
,
device
=
reshape_q
.
device
)
k_scale
=
layer
.
k_scale
if
layer
.
k_scale
is
not
None
else
torch
.
tensor
([
1.0
],
dtype
=
torch
.
float32
,
device
=
reshape_q
.
device
)
o
=
self
.
_call_fp8_decode
(
o
=
self
.
_call_fp8_decode
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment