Commit 59b01a00 authored by linhai1's avatar linhai1
Browse files

Support fp8_e4m3 and fp8_e5m2 KV-cache dtypes.

parent 34f0ebb1
......@@ -394,9 +394,11 @@ class DCUMLABackend(AttentionBackend):
getattr(torch, "float8_e5m2", None),
getattr(torch, "float8_e5m2fnuz", None),
):
if k_cache_reshaped.dtype == torch.float8_e4m3fnuz:
if k_cache_reshaped.dtype == torch.float8_e4m3fnuz or \
k_cache_reshaped.dtype == torch.float8_e4m3fn:
kv_cache_dtype="fp8_e4m3"
elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz:
elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz or \
k_cache_reshaped.dtype == torch.float8_e5m2:
kv_cache_dtype="fp8_e5m2"
k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
o = self._call_fp8_decode(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment