"doc/git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "8384badc69c219c51918c70ac5e7eb8528253b3e"
Commit 59b01a00 authored by linhai1

Support fp8_e4m3 and fp8_e5m2 KV-cache dtypes in DCUMLABackend.

parent 34f0ebb1
@@ -394,9 +394,11 @@ class DCUMLABackend(AttentionBackend):
             getattr(torch, "float8_e5m2", None),
             getattr(torch, "float8_e5m2fnuz", None),
         ):
-            if k_cache_reshaped.dtype == torch.float8_e4m3fnuz:
+            if k_cache_reshaped.dtype == torch.float8_e4m3fnuz or \
+               k_cache_reshaped.dtype == torch.float8_e4m3fn:
                 kv_cache_dtype="fp8_e4m3"
-            elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz:
+            elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz or \
+                 k_cache_reshaped.dtype == torch.float8_e5m2:
                 kv_cache_dtype="fp8_e5m2"
             k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
             o = self._call_fp8_decode(
...
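
The diff above extends the dtype check so that both the "fnuz" and the standard torch fp8 dtypes map to the same kernel-facing tag. The sketch below is illustrative only, not part of the commit: it shows the same mapping as a standalone helper, guarding with getattr for PyTorch builds that lack some fp8 dtypes. The helper name kv_cache_dtype_tag is hypothetical.

# Illustrative sketch (assumption: not the commit's code). Maps a torch fp8
# KV-cache dtype to the string tag used when dispatching the fp8 decode path.
import torch

# fp8 dtypes may be absent on older PyTorch builds, so collect only the ones
# that exist in this installation.
_FP8_E4M3 = tuple(
    d for d in (getattr(torch, "float8_e4m3fn", None),
                getattr(torch, "float8_e4m3fnuz", None))
    if d is not None
)
_FP8_E5M2 = tuple(
    d for d in (getattr(torch, "float8_e5m2", None),
                getattr(torch, "float8_e5m2fnuz", None))
    if d is not None
)

def kv_cache_dtype_tag(dtype: torch.dtype) -> str:
    """Return "fp8_e4m3" or "fp8_e5m2" for a supported fp8 KV-cache dtype."""
    if dtype in _FP8_E4M3:
        return "fp8_e4m3"
    if dtype in _FP8_E5M2:
        return "fp8_e5m2"
    raise ValueError(f"unsupported fp8 KV-cache dtype: {dtype}")

# Example: kv_cache_dtype_tag(torch.float8_e4m3fn) -> "fp8_e4m3"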