Commit 88411543 authored by zhuwenwen's avatar zhuwenwen
Browse files

add fp8 support on bw

parent 625b0b5e
...@@ -69,7 +69,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: ...@@ -69,7 +69,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
def flash_attn_supports_fp8() -> bool: def flash_attn_supports_fp8() -> bool:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": if current_platform.is_rocm():
return True return True
return get_flash_attn_version() == 3 and \ return get_flash_attn_version() == 3 and \
current_platform.get_device_capability().major == 9 current_platform.get_device_capability().major == 9
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment