Commit 625b0b5e authored by zhuwenwen's avatar zhuwenwen
Browse files

only support fp8 e4m3 on nmz

parent a3488ab0
......@@ -142,7 +142,10 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
return torch.float8_e4m3fn
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return torch.float8_e4m3fn
else:
raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")
elif kv_cache_dtype in ("fp8_e5m2"):
return torch.float8_e5m2
else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment