Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
560ae963
Unverified
Commit
560ae963
authored
Dec 20, 2025
by
Yan Ma
Committed by
GitHub
Dec 20, 2025
Browse files
[XPU] enable fp8 online streaming quantization (#30944)
Signed-off-by:
Yan Ma
<
yan.ma@intel.com
>
parent
1501a407
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
107 deletions
+29
-107
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+12
-2
vllm/model_executor/layers/quantization/ipex_quant.py
vllm/model_executor/layers/quantization/ipex_quant.py
+17
-105
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
560ae963
...
@@ -124,11 +124,13 @@ def get_fp8_moe_backend(
...
@@ -124,11 +124,13 @@ def get_fp8_moe_backend(
block_quant
:
bool
,
block_quant
:
bool
,
moe_parallel_config
:
FusedMoEParallelConfig
,
moe_parallel_config
:
FusedMoEParallelConfig
,
with_lora_support
:
bool
,
with_lora_support
:
bool
,
)
->
Fp8MoeBackend
:
)
->
Fp8MoeBackend
|
None
:
"""
"""
Select the primary FP8 MoE backend
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
Note: Shape-specific fallbacks may still occur at runtime.
"""
"""
if
current_platform
.
is_xpu
():
return
None
if
with_lora_support
:
if
with_lora_support
:
return
Fp8MoeBackend
.
TRITON
return
Fp8MoeBackend
.
TRITON
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
...
@@ -292,6 +294,13 @@ class Fp8Config(QuantizationConfig):
...
@@ -292,6 +294,13 @@ class Fp8Config(QuantizationConfig):
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
return
XPUFp8LinearMethod
(
fp8_config
)
return
XPUFp8LinearMethod
(
fp8_config
)
elif
isinstance
(
layer
,
FusedMoE
):
elif
isinstance
(
layer
,
FusedMoE
):
if
is_layer_skipped
(
prefix
=
prefix
,
ignored_layers
=
self
.
ignored_layers
,
fused_mapping
=
self
.
packed_modules_mapping
,
):
return
UnquantizedFusedMoEMethod
(
layer
.
moe_config
)
return
XPUFp8MoEMethod
(
fp8_config
,
layer
)
return
XPUFp8MoEMethod
(
fp8_config
,
layer
)
elif
isinstance
(
layer
,
Attention
):
elif
isinstance
(
layer
,
Attention
):
return
Fp8KVCacheMethod
(
self
)
return
Fp8KVCacheMethod
(
self
)
...
@@ -1107,7 +1116,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -1107,7 +1116,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
if
(
if
(
self
.
rocm_aiter_moe_enabled
current_platform
.
is_xpu
()
or
self
.
rocm_aiter_moe_enabled
or
self
.
use_marlin
or
self
.
use_marlin
or
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
or
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
):
):
...
...
vllm/model_executor/layers/quantization/ipex_quant.py
View file @
560ae963
...
@@ -6,13 +6,8 @@ from typing import Any, Optional
...
@@ -6,13 +6,8 @@ from typing import Any, Optional
import
torch
import
torch
from
packaging
import
version
from
packaging
import
version
from
torch.nn
import
Module
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm._ipex_ops
import
ipex_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
,
)
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearBase
,
...
@@ -24,14 +19,14 @@ from vllm.model_executor.layers.quantization import (
...
@@ -24,14 +19,14 @@ from vllm.model_executor.layers.quantization import (
QuantizationMethods
,
QuantizationMethods
,
)
)
from
vllm.model_executor.layers.quantization.awq
import
AWQLinearMethod
from
vllm.model_executor.layers.quantization.awq
import
AWQLinearMethod
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
,
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8Config
,
Fp8LinearMethod
,
Fp8OnlineMoEMethod
,
)
from
vllm.model_executor.layers.quantization.gptq
import
GPTQLinearMethod
from
vllm.model_executor.layers.quantization.gptq
import
GPTQLinearMethod
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
is_layer_skipped
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
is_layer_skipped
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.utils
import
replace_parameter
maybe_create_device_identity
,
)
from
vllm.model_executor.parameter
import
ModelWeightParameter
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
MIN_IPEX_VERSION
=
"2.6.0"
MIN_IPEX_VERSION
=
"2.6.0"
...
@@ -309,44 +304,15 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
...
@@ -309,44 +304,15 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
def
__init__
(
self
,
quant_config
:
Fp8Config
):
def
__init__
(
self
,
quant_config
:
Fp8Config
):
super
().
__init__
(
quant_config
)
super
().
__init__
(
quant_config
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
maybe_create_device_identity
()
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
params_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight"
,
weight
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
# If checkpoint not serialized fp8, quantize the weights.
# If checkpoint not serialized fp8, quantize the weights.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
# Update the layer with the new values.
# Update the layer with the new values.
layer
.
weight
=
P
arameter
(
q
weight
,
requires_grad
=
False
)
replace_p
arameter
(
layer
,
"
weight
"
,
qweight
.
data
)
layer
.
weight_scale
=
P
arameter
(
weight_scale
,
requires_grad
=
False
)
replace_p
arameter
(
layer
,
"
weight_scale
"
,
weight_scale
.
data
)
layer
.
input_scale
=
None
layer
.
input_scale
=
None
def
apply
(
def
apply
(
...
@@ -363,69 +329,14 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
...
@@ -363,69 +329,14 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
return
output
return
output
class
XPUFp8MoEMethod
(
F
used
MoEMethod
Base
):
class
XPUFp8MoEMethod
(
F
p8Online
MoEMethod
):
def
__init__
(
self
,
quant_config
:
Fp8Config
,
layer
:
torch
.
nn
.
Module
):
def
__init__
(
self
,
quant_config
:
Fp8Config
,
layer
:
torch
.
nn
.
Module
):
super
().
__init__
(
layer
.
moe_config
)
super
().
__init__
(
quant_config
,
layer
)
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
intermediate_size_per_partition
=
intermediate_size_per_partition
layer
.
hidden_size
=
hidden_size
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
}
)
# INPUT_SCALES
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
fp8_dtype
=
current_platform
.
fp8_dtype
()
fp8_dtype
=
current_platform
.
fp8_dtype
()
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
fp8_dtype
)
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
fp8_dtype
)
...
@@ -448,8 +359,9 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
...
@@ -448,8 +359,9 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
w2_weight
[
expert
,
:,
:],
layer
.
w2_weight_scale
[
expert
]
=
(
w2_weight
[
expert
,
:,
:],
layer
.
w2_weight_scale
[
expert
]
=
(
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
.
data
[
expert
,
:,
:])
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
.
data
[
expert
,
:,
:])
)
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
replace_parameter
(
layer
,
"w13_weight"
,
w13_weight
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
replace_parameter
(
layer
,
"w2_weight"
,
w2_weight
)
import
intel_extension_for_pytorch
as
ipex
import
intel_extension_for_pytorch
as
ipex
ep_rank_start
=
self
.
moe
.
ep_rank
*
self
.
moe
.
num_local_experts
ep_rank_start
=
self
.
moe
.
ep_rank
*
self
.
moe
.
num_local_experts
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment