Unverified Commit 40a36ccf authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by GitHub
Browse files

[ROCm][Bugfix] Use platform specific FP8 dtype (#15717)


Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent ef608c37
...@@ -753,7 +753,7 @@ if triton.__version__ >= "2.1.0": ...@@ -753,7 +753,7 @@ if triton.__version__ >= "2.1.0":
assert (v_cache.dtype == torch.uint8) assert (v_cache.dtype == torch.uint8)
if kv_cache_dtype in ("fp8", "fp8_e4m3"): if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = torch.float8_e4m3fn target_dtype = current_platform.fp8_dtype()
elif kv_cache_dtype == "fp8_e5m2": elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2 target_dtype = torch.float8_e5m2
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment