Unverified Commit e3908bf5 authored by Yizhi Wang's avatar Yizhi Wang Committed by GitHub
Browse files

[fix] handle empty tensor in per_token_cast_back (#360)

parent d9767ce0
...@@ -52,6 +52,8 @@ def per_token_cast_to_fp8(x: torch.Tensor): ...@@ -52,6 +52,8 @@ def per_token_cast_to_fp8(x: torch.Tensor):
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor): def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
if x_fp8.numel() == 0:
return x_fp8.to(torch.bfloat16)
if x_scales.dtype == torch.int: if x_scales.dtype == torch.int:
x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23 x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23
x_scales = x_scales.view(dtype=torch.float) x_scales = x_scales.view(dtype=torch.float)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment