Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fc701c80
Unverified
Commit
fc701c80
authored
Apr 15, 2026
by
zofia
Committed by
GitHub
Apr 15, 2026
Browse files
[XPU][MXFP4] add mxfp4 quant op for XPU (#39857)
Signed-off-by:
Zhu, Zufang
<
zufang.zhu@intel.com
>
parent
68be0f85
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
0 deletions
+50
-0
vllm/_xpu_ops.py
vllm/_xpu_ops.py
+46
-0
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+4
-0
No files found.
vllm/_xpu_ops.py
View file @
fc701c80
...
@@ -144,6 +144,46 @@ def _xpu_mxfp8_quantize_fake(
...
@@ -144,6 +144,46 @@ def _xpu_mxfp8_quantize_fake(
return
x
.
to
(
dtype
),
x_s
.
to
(
torch
.
float8_e8m0fnu
)
return
x
.
to
(
dtype
),
x_s
.
to
(
torch
.
float8_e8m0fnu
)
def
_xpu_mxfp4_quantize_impl
(
x
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
MXFP4_BLOCK_SIZE
=
32
eps
=
1e-10
assert
x
.
ndim
==
2
,
"input must be 2-D"
assert
x
.
shape
[
-
1
]
%
MXFP4_BLOCK_SIZE
==
0
,
(
f
"last dimension
{
x
.
shape
[
-
1
]
}
must be divisible by group_size "
f
"
{
MXFP4_BLOCK_SIZE
}
"
)
assert
x
.
is_contiguous
(),
"input groups must be contiguous"
M
,
N
=
x
.
shape
# Packed FP4 output: two nibbles per byte
x_q
=
torch
.
empty
(
M
,
N
//
2
,
device
=
x
.
device
,
dtype
=
torch
.
uint8
)
x_s
=
torch
.
empty
(
M
,
N
//
MXFP4_BLOCK_SIZE
,
device
=
x
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
per_token_group_quant_mxfp4
(
x
,
x_q
,
x_s
,
MXFP4_BLOCK_SIZE
,
eps
)
x_q
=
x_q
.
view
(
torch
.
float4_e2m1fn_x2
)
x_s
=
x_s
.
to
(
dtype
=
torch
.
float8_e8m0fnu
,
memory_format
=
torch
.
preserve_format
)
return
x_q
,
x_s
def
_xpu_mxfp4_quantize_fake
(
x
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
MXFP4_BLOCK_SIZE
=
32
M
,
N
=
x
.
shape
# Packed FP4 output: two nibbles per byte
x_q
=
torch
.
empty
(
M
,
N
//
2
,
device
=
x
.
device
,
dtype
=
torch
.
uint8
)
x_s
=
torch
.
empty
(
M
,
N
//
MXFP4_BLOCK_SIZE
,
device
=
x
.
device
,
dtype
=
torch
.
float32
)
x_q
=
x_q
.
view
(
torch
.
float4_e2m1fn_x2
)
x_s
=
x_s
.
to
(
dtype
=
torch
.
float8_e8m0fnu
,
memory_format
=
torch
.
preserve_format
)
return
x_q
,
x_s
# Global flag to ensure ops are registered only once
# Global flag to ensure ops are registered only once
_OPS_REGISTERED
=
False
_OPS_REGISTERED
=
False
...
@@ -555,6 +595,12 @@ class xpu_ops:
...
@@ -555,6 +595,12 @@ class xpu_ops:
fake_impl
=
_xpu_mxfp8_quantize_fake
,
fake_impl
=
_xpu_mxfp8_quantize_fake
,
)
)
direct_register_custom_op
(
op_name
=
"xpu_mxfp4_quantize"
,
op_func
=
_xpu_mxfp4_quantize_impl
,
fake_impl
=
_xpu_mxfp4_quantize_fake
,
)
_OPS_REGISTERED
=
True
_OPS_REGISTERED
=
True
...
...
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
View file @
fc701c80
...
@@ -162,3 +162,7 @@ try:
...
@@ -162,3 +162,7 @@ try:
quant_dequant_mxfp4
=
torch
.
ops
.
vllm
.
quant_dequant_mxfp4
quant_dequant_mxfp4
=
torch
.
ops
.
vllm
.
quant_dequant_mxfp4
except
AttributeError
as
error
:
except
AttributeError
as
error
:
raise
error
raise
error
def
xpu_mxfp4_quantize
(
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
vllm
.
xpu_mxfp4_quantize
(
x
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment