change / sglang · Commits · c9bf3877
"vscode:/vscode.git/clone" did not exist on "0db19da01f2322485e6e2fe84cec39869e0f35cc"
Commit c9bf3877 (Unverified)
Authored Aug 20, 2025 by Yichen Yan; committed by GitHub, Aug 20, 2025
Parent: de2dd738

Reduce overhead for fa by not calling heavy CUDA property check (#7375)
Showing 2 changed files with 7 additions and 5 deletions (+7 -5):

  sgl-kernel/python/sgl_kernel/flash_attn.py      +5 -3
  sgl-kernel/tests/test_flash_attention.py        +2 -2
sgl-kernel/python/sgl_kernel/flash_attn.py  (view file @ c9bf3877)

-from typing import List, Optional, Tuple, Union
+from functools import lru_cache
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
...
@@ -9,6 +10,7 @@ except:
     raise ImportError("Can not import sgl_kernel. Please check your installation.")
 
 
+@lru_cache(maxsize=1)
 def is_fa3_supported(device=None) -> bool:
     # There some fa3 FYI
     # FA3 can fail without a enough shared memory for a some shapes, such as higher
...
@@ -18,10 +20,10 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
+    return (torch.version.cuda >= "12.3") and (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
+    )
 
 
 def maybe_contiguous(x):
...
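The speedup comes from two ordinary Python mechanisms rather than any kernel change: `and` short-circuits left to right, so putting the cheap `torch.version.cuda` string comparison first means the comparatively heavy `torch.cuda.get_device_capability()` query never runs when the CUDA version check already fails, and `functools.lru_cache(maxsize=1)` memoizes the result so the device query happens at most once even on supported hardware. A minimal sketch of the pattern, using a hypothetical `heavy_device_query()` stand-in (not part of sgl-kernel) so it runs without a GPU:

from functools import lru_cache

def heavy_device_query():
    # Hypothetical stand-in for torch.cuda.get_device_capability(),
    # which touches the CUDA driver and queries device properties.
    print("heavy device query executed")
    return (9, 0)

@lru_cache(maxsize=1)
def is_fa3_supported_sketch(cuda_version):
    # Cheap string check first: if it fails, `and` short-circuits and
    # the heavy query on the right-hand side is never evaluated.
    return (cuda_version >= "12.3") and (heavy_device_query()[0] in (8, 9))

is_fa3_supported_sketch("12.1")   # False; heavy query skipped entirely
is_fa3_supported_sketch("12.4")   # True; heavy query runs once
is_fa3_supported_sketch("12.4")   # cached; heavy query is not repeated

In the real helper the cache is keyed on the `device` argument (usually None), so every call after the first returns the cached boolean without evaluating either check again.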
sgl-kernel/tests/test_flash_attention.py  (view file @ c9bf3877)

...
@@ -25,10 +25,10 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
+    return (torch.version.cuda >= "12.3") and (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
+    )
 
 
 DISABLE_BACKWARD = True
...
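How the test file consumes this helper is outside the visible hunks, but a gate like `is_fa3_supported()` is typically wired into pytest skips so FA3 cases simply do not run on unsupported machines. A hypothetical usage sketch (the marker, test name, and body are illustrative, not taken from the file):

import pytest
import torch

def is_fa3_supported(device=None) -> bool:
    # Mirrors the check from the hunk above.
    return (torch.version.cuda >= "12.3") and (
        torch.cuda.get_device_capability(device)[0] == 9
        or torch.cuda.get_device_capability(device)[0] == 8
    )

fa3_only = pytest.mark.skipif(
    not torch.cuda.is_available() or not is_fa3_supported(),
    reason="FA3 needs CUDA >= 12.3 on sm80/sm86/sm89/sm90a",
)

@fa3_only
def test_fa3_gate_smoke():
    # Illustrative placeholder; a real test would exercise the FA3 kernels.
    assert torch.cuda.get_device_capability()[0] in (8, 9)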