Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
560ae963
Unverified
Commit
560ae963
authored
Dec 20, 2025
by
Yan Ma
Committed by
GitHub
Dec 20, 2025
Browse files
[XPU] enable fp8 online streaming quantization (#30944)
Signed-off-by:
Yan Ma
<
yan.ma@intel.com
>
parent
1501a407
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
107 deletions
+29
-107
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+12
-2
vllm/model_executor/layers/quantization/ipex_quant.py
vllm/model_executor/layers/quantization/ipex_quant.py
+17
-105
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
560ae963
...
...
@@ -124,11 +124,13 @@ def get_fp8_moe_backend(
block_quant
:
bool
,
moe_parallel_config
:
FusedMoEParallelConfig
,
with_lora_support
:
bool
,
)
->
Fp8MoeBackend
:
)
->
Fp8MoeBackend
|
None
:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
if
current_platform
.
is_xpu
():
return
None
if
with_lora_support
:
return
Fp8MoeBackend
.
TRITON
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
...
...
@@ -292,6 +294,13 @@ class Fp8Config(QuantizationConfig):
return
UnquantizedLinearMethod
()
return
XPUFp8LinearMethod
(
fp8_config
)
elif
isinstance
(
layer
,
FusedMoE
):
if
is_layer_skipped
(
prefix
=
prefix
,
ignored_layers
=
self
.
ignored_layers
,
fused_mapping
=
self
.
packed_modules_mapping
,
):
return
UnquantizedFusedMoEMethod
(
layer
.
moe_config
)
return
XPUFp8MoEMethod
(
fp8_config
,
layer
)
elif
isinstance
(
layer
,
Attention
):
return
Fp8KVCacheMethod
(
self
)
...
...
@@ -1107,7 +1116,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
if
(
self
.
rocm_aiter_moe_enabled
current_platform
.
is_xpu
()
or
self
.
rocm_aiter_moe_enabled
or
self
.
use_marlin
or
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
):
...
...
vllm/model_executor/layers/quantization/ipex_quant.py
View file @
560ae963
...
...
@@ -6,13 +6,8 @@ from typing import Any, Optional
import
torch
from
packaging
import
version
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm._ipex_ops
import
ipex_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
,
)
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
...
...
@@ -24,14 +19,14 @@ from vllm.model_executor.layers.quantization import (
QuantizationMethods
,
)
from
vllm.model_executor.layers.quantization.awq
import
AWQLinearMethod
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
,
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8Config
,
Fp8LinearMethod
,
Fp8OnlineMoEMethod
,
)
from
vllm.model_executor.layers.quantization.gptq
import
GPTQLinearMethod
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
is_layer_skipped
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
maybe_create_device_identity
,
)
from
vllm.model_executor.parameter
import
ModelWeightParameter
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
replace_parameter
from
vllm.platforms
import
current_platform
MIN_IPEX_VERSION
=
"2.6.0"
...
...
@@ -309,44 +304,15 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
def
__init__
(
self
,
quant_config
:
Fp8Config
):
super
().
__init__
(
quant_config
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
maybe_create_device_identity
()
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
params_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight"
,
weight
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
# If checkpoint not serialized fp8, quantize the weights.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
# Update the layer with the new values.
layer
.
weight
=
P
arameter
(
q
weight
,
requires_grad
=
False
)
layer
.
weight_scale
=
P
arameter
(
weight_scale
,
requires_grad
=
False
)
replace_p
arameter
(
layer
,
"
weight
"
,
qweight
.
data
)
replace_p
arameter
(
layer
,
"
weight_scale
"
,
weight_scale
.
data
)
layer
.
input_scale
=
None
def
apply
(
...
...
@@ -363,69 +329,14 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
return
output
class
XPUFp8MoEMethod
(
F
used
MoEMethod
Base
):
class
XPUFp8MoEMethod
(
F
p8Online
MoEMethod
):
def
__init__
(
self
,
quant_config
:
Fp8Config
,
layer
:
torch
.
nn
.
Module
):
super
().
__init__
(
layer
.
moe_config
)
super
().
__init__
(
quant_config
,
layer
)
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
intermediate_size_per_partition
=
intermediate_size_per_partition
layer
.
hidden_size
=
hidden_size
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
}
)
# INPUT_SCALES
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
fp8_dtype
=
current_platform
.
fp8_dtype
()
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
fp8_dtype
)
...
...
@@ -448,8 +359,9 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
w2_weight
[
expert
,
:,
:],
layer
.
w2_weight_scale
[
expert
]
=
(
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
.
data
[
expert
,
:,
:])
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
replace_parameter
(
layer
,
"w13_weight"
,
w13_weight
)
replace_parameter
(
layer
,
"w2_weight"
,
w2_weight
)
import
intel_extension_for_pytorch
as
ipex
ep_rank_start
=
self
.
moe
.
ep_rank
*
self
.
moe
.
num_local_experts
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment