"vscode:/vscode.git/clone" did not exist on "181bf5bbde1b4f490287c0e7b679e0b31e912b14"
Unverified commit fc701c80, authored by zofia, committed by GitHub
Browse files

[XPU][MXFP4] add mxfp4 quant op for XPU (#39857)


Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
parent 68be0f85
...@@ -144,6 +144,46 @@ def _xpu_mxfp8_quantize_fake( ...@@ -144,6 +144,46 @@ def _xpu_mxfp8_quantize_fake(
return x.to(dtype), x_s.to(torch.float8_e8m0fnu) return x.to(dtype), x_s.to(torch.float8_e8m0fnu)
def _xpu_mxfp4_quantize_impl(
    x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor with the XPU ``per_token_group_quant_mxfp4`` op.

    Args:
        x: Contiguous 2-D tensor whose last dimension is a multiple of 32.

    Returns:
        A ``(x_q, x_s)`` pair. ``x_q`` is an ``(M, N // 2)`` uint8 buffer
        reinterpreted as ``torch.float4_e2m1fn_x2``. ``x_s`` is an
        ``(M, N // 32)`` float32 buffer converted to
        ``torch.float8_e8m0fnu``.

    Raises:
        AssertionError: If ``x`` is not 2-D, its last dimension is not
            divisible by the block size, or it is not contiguous.
    """
    MXFP4_BLOCK_SIZE = 32
    epsilon = 1e-10

    # Preconditions are checked in this exact order before any allocation.
    assert x.ndim == 2, "input must be 2-D"
    assert x.shape[-1] % MXFP4_BLOCK_SIZE == 0, (
        f"last dimension {x.shape[-1]} must be divisible by group_size "
        f"{MXFP4_BLOCK_SIZE}"
    )
    assert x.is_contiguous(), "input groups must be contiguous"

    num_rows, num_cols = x.shape

    # Packed FP4 output: two nibbles per byte, hence half as many columns.
    packed = torch.empty(
        (num_rows, num_cols // 2), dtype=torch.uint8, device=x.device
    )
    # One float32 scale slot per block of MXFP4_BLOCK_SIZE input columns.
    scales = torch.empty(
        (num_rows, num_cols // MXFP4_BLOCK_SIZE),
        dtype=torch.float32,
        device=x.device,
    )

    # The op's return value is unused; presumably it fills `packed` and
    # `scales` in place -- confirm against the kernel's signature.
    torch.ops._C.per_token_group_quant_mxfp4(
        x, packed, scales, MXFP4_BLOCK_SIZE, epsilon
    )

    return (
        packed.view(torch.float4_e2m1fn_x2),
        scales.to(
            dtype=torch.float8_e8m0fnu, memory_format=torch.preserve_format
        ),
    )
def _xpu_mxfp4_quantize_fake(
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
MXFP4_BLOCK_SIZE = 32
M, N = x.shape
# Packed FP4 output: two nibbles per byte
x_q = torch.empty(M, N // 2, device=x.device, dtype=torch.uint8)
x_s = torch.empty(M, N // MXFP4_BLOCK_SIZE, device=x.device, dtype=torch.float32)
x_q = x_q.view(torch.float4_e2m1fn_x2)
x_s = x_s.to(dtype=torch.float8_e8m0fnu, memory_format=torch.preserve_format)
return x_q, x_s
# Global flag to ensure ops are registered only once # Global flag to ensure ops are registered only once
_OPS_REGISTERED = False _OPS_REGISTERED = False
...@@ -555,6 +595,12 @@ class xpu_ops: ...@@ -555,6 +595,12 @@ class xpu_ops:
fake_impl=_xpu_mxfp8_quantize_fake, fake_impl=_xpu_mxfp8_quantize_fake,
) )
direct_register_custom_op(
op_name="xpu_mxfp4_quantize",
op_func=_xpu_mxfp4_quantize_impl,
fake_impl=_xpu_mxfp4_quantize_fake,
)
_OPS_REGISTERED = True _OPS_REGISTERED = True
......
...@@ -162,3 +162,7 @@ try: ...@@ -162,3 +162,7 @@ try:
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4 quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
except AttributeError as error: except AttributeError as error:
raise error raise error
def xpu_mxfp4_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward ``x`` to the ``torch.ops.vllm.xpu_mxfp4_quantize`` custom op.

    The op is looked up at call time, so it must have been registered
    under the ``vllm`` namespace before this function is invoked.
    """
    quantize_op = torch.ops.vllm.xpu_mxfp4_quantize
    return quantize_op(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment