Unverified Commit 4f288113 authored by Yineng Zhang, committed by GitHub

fix: update flash attn (#5308)

parent 136b8e6a
@@ -10,15 +10,9 @@ except:
 def is_fa3_supported(device=None) -> bool:
-    # FA3 can fail without a enough shared memory for a some shapes, currently
-    # only 8.0 and 8.7 have enough shared memory for all shapes
-    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
-    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
-    return (
-        (torch.cuda.get_device_capability(device)[0] >= 9)
-        and (torch.version.cuda >= "12.4")
-        # or torch.cuda.get_device_capability(device) == (8, 0)
-        # or torch.cuda.get_device_capability(device) == (8, 7)
-    )
+    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.3
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
+    )
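The change tightens the gate in two ways: the capability check now requires exactly major compute capability 9 (sm90 / Hopper) instead of 9 or above, and the CUDA floor drops from 12.4 to 12.3. A minimal self-contained sketch of the updated check follows; note that `torch.version.cuda` is a string, so the `>= "12.3"` comparison is lexicographic, and it is `None` on CPU-only torch builds.

```python
# Minimal sketch of the updated gate; mirrors the diff, not a new API.
import torch


def is_fa3_supported(device=None) -> bool:
    # Exactly major capability 9 (sm90 / Hopper). torch.version.cuda is a
    # string, so ">=" is a lexicographic comparison against "12.3"; it is
    # None on CPU-only builds, where this comparison would raise TypeError.
    return (torch.cuda.get_device_capability(device)[0] == 9) and (
        torch.version.cuda >= "12.3"
    )


if __name__ == "__main__":
    if torch.cuda.is_available():
        print("FA3 supported:", is_fa3_supported())
    else:
        print("No CUDA device visible")
```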
@@ -144,7 +138,7 @@ def flash_attn_with_kvcache(
     """
     if not is_fa3_supported():
         raise NotImplementedError(
-            "flash_attn at sgl-kernel is only supported on sm90 and above"
+            "flash_attn at sgl-kernel is only supported on sm90 and cu123 above"
         )
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
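Callers on pre-sm90 GPUs or CUDA below 12.3 now hit the updated `NotImplementedError` message. A hedged usage sketch follows; the import path and tensor shapes are assumptions for illustration, not part of this diff.

```python
# Usage sketch (assumed import path and shapes; not taken from this diff).
import torch

from sgl_kernel.flash_attn import flash_attn_with_kvcache  # assumed module path

# Query of shape (batch, seqlen_q, nheads, headdim) against a KV cache of
# shape (batch, seqlen_cache, nheads_k, headdim), bf16 on the CUDA device.
q = torch.randn(1, 1, 8, 128, dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
v_cache = torch.randn_like(k_cache)  # contiguous, so stride(-1) == 1 holds

try:
    out = flash_attn_with_kvcache(q, k_cache, v_cache)
except NotImplementedError as err:
    # Raised by the gate above on pre-sm90 GPUs or CUDA < 12.3.
    print(f"FA3 path unavailable: {err}")
```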