Commit 04ea3540 authored by linhai1

support fp8_e4m3.

parent 50f7ea0f
@@ -465,9 +465,9 @@ class DCUMLABackend(AttentionBackend):
            getattr(torch, "float8_e5m2fnuz", None),
        ):
            if k_cache_reshaped.dtype == torch.float8_e4m3fnuz:
                # Reinterpret the fnuz fp8 storage as the OCP dtype the
                # fp8 decode kernel accepts, and record the matching tag.
                k_cache_reshaped = k_cache_reshaped.view(torch.float8_e4m3fn)
                kv_cache_dtype = "fp8_e4m3"
            elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz:
                k_cache_reshaped = k_cache_reshaped.view(torch.float8_e5m2)
                kv_cache_dtype = "fp8_e5m2"
            # Default to a unit scale when the layer provides no k_scale.
            k_scale = (
                layer.k_scale
                if layer.k_scale is not None
                else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
            )
            o = self._call_fp8_decode(
                reshape_q,
@@ -476,7 +476,7 @@ class DCUMLABackend(AttentionBackend):
                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                layer.scaling,
                k_scale.to(torch.float32),
-               kv_cache_dtype=self.data_type,
+               kv_cache_dtype=kv_cache_dtype,
            )
        else:
            o = self._call_decode(
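For readers outside the codebase: when the kv cache arrives in one of the fnuz fp8 formats, the commit bit-reinterprets it into the corresponding OCP fp8 dtype and threads the matching "fp8_e4m3" / "fp8_e5m2" tag through to the decode call instead of the backend-wide self.data_type. Below is a minimal standalone sketch of that mapping; the normalize_fp8_kv_cache helper and its pass-through fallback are illustrative, not part of the commit.

```python
# Sketch of the fnuz -> OCP fp8 reinterpretation performed in the diff.
# Assumes a PyTorch build that defines the fnuz dtypes (the actual code
# guards this with getattr(torch, "float8_e5m2fnuz", None)).
import torch


def normalize_fp8_kv_cache(k_cache: torch.Tensor):
    """Hypothetical helper: map fnuz fp8 storage to the OCP fp8 dtypes
    and return the dtype tag handed to the fp8 decode kernel."""
    if k_cache.dtype == torch.float8_e4m3fnuz:
        # view() reinterprets the existing 1-byte elements; no copy.
        return k_cache.view(torch.float8_e4m3fn), "fp8_e4m3"
    if k_cache.dtype == torch.float8_e5m2fnuz:
        return k_cache.view(torch.float8_e5m2), "fp8_e5m2"
    # Anything else passes through untouched with no fp8 tag.
    return k_cache, None


cache = torch.empty(8, dtype=torch.float8_e4m3fnuz)
cache, kv_cache_dtype = normalize_fp8_kv_cache(cache)
assert cache.dtype == torch.float8_e4m3fn and kv_cache_dtype == "fp8_e4m3"
```

Tensor.view(dtype) with an equal-sized dtype reinterprets the existing bytes without copying. Since the fnuz and OCP encodings assign different values to the same bit pattern, this is a reinterpretation rather than a numeric conversion; dequantization is left to the kernel, which also receives the per-layer k_scale (defaulting to 1.0 in the hunk above).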