Unverified Commit 55842a8d authored by Xinyu Chen's avatar Xinyu Chen Committed by GitHub
Browse files

[XPU]fake impl for xpu fp8_gemm (#39984)


Signed-off-by: default avatarXinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent 1f45e837
......@@ -22,6 +22,23 @@ else:
except ImportError:
from torch.library import impl_abstract as register_fake
if hasattr(torch.ops._xpu_C, "fp8_gemm"):
@register_fake("_xpu_C::fp8_gemm")
def _fp8_gemm_fake(
q_input: torch.Tensor,
q_weight: torch.Tensor,
out_dtype: torch.dtype,
input_scales: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
input_2d = q_input.view(-1, q_input.shape[-1])
M = input_2d.size(0)
N = q_weight.size(1)
return torch.empty((M, N), dtype=out_dtype, device=q_input.device)
if hasattr(torch.ops._xpu_C, "fp8_gemm_w8a16"):
@register_fake("_xpu_C::fp8_gemm_w8a16")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment