Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ad2b1277
Unverified
Commit
ad2b1277
authored
Apr 16, 2026
by
Asaf Gardin
Committed by
GitHub
Apr 16, 2026
Browse files
[Quantization] Consolidate experts_int8 with fp8 online quantization (#38463)
Signed-off-by:
Josephasafg
<
ajgard7@gmail.com
>
parent
b897f00c
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
403 additions
and
313 deletions
+403
-313
tests/quantization/test_experts_int8.py
tests/quantization/test_experts_int8.py
+0
-1
vllm/config/quantization.py
vllm/config/quantization.py
+4
-0
vllm/model_executor/layers/fused_moe/oracle/int8.py
vllm/model_executor/layers/fused_moe/oracle/int8.py
+84
-0
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+1
-1
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+11
-157
vllm/model_executor/layers/quantization/online/base.py
vllm/model_executor/layers/quantization/online/base.py
+16
-2
vllm/model_executor/layers/quantization/online/fp8.py
vllm/model_executor/layers/quantization/online/fp8.py
+6
-152
vllm/model_executor/layers/quantization/online/int8.py
vllm/model_executor/layers/quantization/online/int8.py
+109
-0
vllm/model_executor/layers/quantization/online/moe_base.py
vllm/model_executor/layers/quantization/online/moe_base.py
+172
-0
No files found.
tests/quantization/test_experts_int8.py
View file @
ad2b1277
...
...
@@ -38,6 +38,5 @@ def test_model_experts_int8_startup(
dtype
=
dtype
,
enforce_eager
=
True
,
quantization
=
"experts_int8"
,
allow_deprecated_quantization
=
True
,
)
as
vllm_model
:
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
vllm/config/quantization.py
View file @
ad2b1277
...
...
@@ -19,6 +19,10 @@ class OnlineQuantScheme(Enum):
# blocks of 128x128 elements (popularized by DeepSeek)
FP8_PER_BLOCK
=
"fp8_per_block"
# int8, weight-only per-channel quantization for MoE expert weights.
# Linear layers remain unquantized.
INT8_PER_CHANNEL_WEIGHT_ONLY
=
"int8_per_channel_weight_only"
# TODO(future PRs): add more online quant schemes here: mxfp8, etc
...
...
vllm/model_executor/layers/fused_moe/oracle/int8.py
0 → 100644
View file @
ad2b1277
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.all2all_utils
import
(
maybe_make_prepare_finalize
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
int8_w8a16_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.runner.shared_experts
import
(
SharedExperts
,
)
logger
=
init_logger
(
__name__
)
def select_int8_moe_backend(
    config: FusedMoEConfig,
) -> type[mk.FusedMoEExperts]:
    """Pick the experts implementation used for INT8 MoE layers.

    Only the Triton experts class is considered. It is asked whether it
    supports ``config`` with the standard activation format; if it does
    not, the reason it reports is surfaced in the raised error.

    Args:
        config: MoE configuration of the layer being set up.

    Returns:
        The ``TritonExperts`` class (not an instance).

    Raises:
        ValueError: If ``TritonExperts`` reports the configuration as
            unsupported.
    """
    # Imported lazily inside the function, as in the original code.
    from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts

    backend_cls = TritonExperts
    is_ok, why_not = backend_cls.is_supported_config(
        backend_cls,
        config,
        None,
        None,
        mk.FusedMoEActivationFormat.Standard,
    )

    if is_ok:
        logger.info_once("Using Triton INT8 MoE backend", scope="local")
        return backend_cls

    raise ValueError(
        f"INT8 Triton MoE backend does not support the "
        f"deployment configuration: {why_not}"
    )
def make_int8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """Build the fused-MoE quant config for int8 expert weights.

    Thin wrapper around ``int8_w8a16_moe_quant_config`` that forwards the
    two weight-scale tensors and passes no zero points.

    Args:
        w1_scale: Scales for the first (w13) expert projection.
        w2_scale: Scales for the second (w2) expert projection.

    Returns:
        The quant config produced by ``int8_w8a16_moe_quant_config``.
    """
    quant_config = int8_w8a16_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        # No zero points are supplied for either projection.
        w1_zp=None,
        w2_zp=None,
    )
    return quant_config
def make_int8_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEExperts],
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    shared_experts: SharedExperts | None = None,
) -> mk.FusedMoEKernel:
    """Assemble a ``FusedMoEKernel`` for an INT8 MoE layer.

    The kernel is composed of a prepare/finalize object (obtained from
    ``maybe_make_prepare_finalize``) and an experts object (an instance of
    ``experts_cls``).

    Args:
        moe_quant_config: Quantization config handed to both components.
        moe_config: MoE layer configuration.
        experts_cls: Experts implementation to instantiate.
        routing_tables: Optional routing tables forwarded to the
            prepare/finalize factory.
        shared_experts: Optional shared experts attached to the kernel.

    Returns:
        The constructed ``mk.FusedMoEKernel``.
    """
    pf = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=routing_tables,
        allow_new_interface=True,
    )
    # With allow_new_interface=True the factory is expected to always
    # hand back an object; fail loudly if that ever stops being true.
    assert pf is not None

    logger.info_once("Using %s", pf.__class__.__name__, scope="local")

    expert_impl = experts_cls(
        moe_config=moe_config,
        quant_config=moe_quant_config,
    )

    run_inplace = not moe_config.disable_inplace
    return mk.FusedMoEKernel(
        pf,
        expert_impl,
        shared_experts=shared_experts,
        inplace=run_inplace,
    )
vllm/model_executor/layers/quantization/__init__.py
View file @
ad2b1277
...
...
@@ -40,6 +40,7 @@ QuantizationMethods = Literal[
# shorthand for creating a more complicated online quant config object
"fp8_per_tensor"
,
"fp8_per_block"
,
"int8_per_channel_weight_only"
,
]
QUANTIZATION_METHODS
:
list
[
str
]
=
list
(
get_args
(
QuantizationMethods
))
...
...
@@ -47,7 +48,6 @@ DEPRECATED_QUANTIZATION_METHODS = [
"tpu_int8"
,
"fbgemm_fp8"
,
"fp_quant"
,
"experts_int8"
,
]
# The customized quantization methods which will be added to this dict.
...
...
vllm/model_executor/layers/quantization/experts_int8.py
View file @
ad2b1277
...
...
@@ -5,27 +5,25 @@ from typing import Any
import
torch
from
vllm.distributed
import
get_tensor_model_parallel_rank
,
get_tp_group
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEConfig
,
FusedMoEMethodBase
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
int8_w8a16_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
LinearBase
,
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.layers.quantization.online.int8
import
(
Int8OnlineMoEMethod
,
)
class
ExpertsInt8Config
(
QuantizationConfig
):
"""Config class for Int8 experts quantization."""
"""Online int8 quantization for MoE expert weights.
Linear layers are left unquantized.
Backward-compatible config for ``--quantization experts_int8``.
Prefer ``--quantization int8_per_channel_weight_only``.
"""
def
__init__
(
self
)
->
None
:
super
().
__init__
()
...
...
@@ -56,149 +54,5 @@ class ExpertsInt8Config(QuantizationConfig):
if
isinstance
(
layer
,
LinearBase
):
return
UnquantizedLinearMethod
()
elif
isinstance
(
layer
,
FusedMoE
):
return
ExpertsInt8MoEMethod
(
self
,
layer
.
moe_config
)
return
Int8OnlineMoEMethod
(
layer
=
layer
)
return
None
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
    """Fused-MoE method that keeps expert weights in int8 with per-row scales.

    The checkpoint weight loader is wrapped (see
    ``quantizing_weight_loader``) so that each expert shard is quantized
    while it is being loaded. The forward pass calls ``fused_experts``
    with the ``int8_w8a16`` quant config built from the stored scales.
    """

    def __init__(
        self,
        quant_config: ExpertsInt8Config,
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register int8 weight and float32 scale parameters on ``layer``.

        Registers ``w13_weight``, ``w2_weight``, ``w13_scale`` and
        ``w2_scale``. ``params_dtype`` is accepted for interface
        compatibility but is not used: weights are always int8 and scales
        are always float32.

        ``extra_weight_attrs`` must contain ``"weight_loader"``; it is
        replaced (in the passed dict) by a quantizing wrapper before the
        attributes are attached to the two weight parameters.
        """
        int8_dtype = torch.int8

        assert "weight_loader" in extra_weight_attrs
        weight_loader = extra_weight_attrs["weight_loader"]
        wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
            layer, weight_loader
        )
        extra_weight_attrs["weight_loader"] = wrapped_weight_loader

        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=int8_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=int8_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # One scale per output row of each expert. The scale parameters get
        # no weight attributes: they are filled by the wrapped loader, not
        # read from the checkpoint.
        w13_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = torch.nn.Parameter(
            torch.zeros(num_experts, hidden_size, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_scale", w2_scale)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return the int8 weight quant config built from the layer's scales."""
        return int8_w8a16_moe_quant_config(
            w1_scale=layer.w13_scale, w2_scale=layer.w2_scale, w1_zp=None, w2_zp=None
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run the fused experts on ``x`` using the layer's int8 weights.

        ``shared_experts_input`` is part of the interface but is not used
        by this implementation.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=not self.moe.disable_inplace,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            quant_config=self.moe_quant_config,
        )

    @staticmethod
    def quantizing_weight_loader(layer, weight_loader):
        """Wrap ``weight_loader`` so weights are quantized before loading.

        The returned callable moves the loaded tensor to the tensor-parallel
        group's device, quantizes this rank's shard in place, stores the
        resulting per-row scales on ``layer``, and then delegates to the
        original ``weight_loader`` with the (partially quantized) tensor.
        """

        def quantize_and_call_weight_loader(
            param: torch.nn.Parameter,
            loaded_weight: torch.Tensor,
            weight_name: str,
            # Shard ids are the strings "w1" / "w2" / "w3", not integers.
            shard_id: str,
            expert_id: int,
        ):
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = layer.intermediate_size_per_partition
            # Slice of the intermediate dimension owned by this TP rank.
            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
            device = get_tp_group().device
            loaded_weight = loaded_weight.to(device)
            # w1, gate_proj case: Load into first shard of w13.
            if shard_id == "w1":
                scales = quantize_in_place_and_get_scales(loaded_weight[shard, :])
                layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:, 0])
            # w3, up_proj case: Load into second shard of w13.
            elif shard_id == "w3":
                scales = quantize_in_place_and_get_scales(loaded_weight[shard, :])
                layer.w13_scale.data[expert_id, shard_size : 2 * shard_size].copy_(
                    scales[:, 0]
                )
            # w2, down_proj case: Load into only shard of w2.
            elif shard_id == "w2":
                scales = quantize_in_place_and_get_scales(loaded_weight[:, shard])
                layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
            else:
                raise ValueError(
                    f"Shard id must be one of ['w1', 'w2', 'w3'] but got {shard_id}"
                )
            # NOTE(review): presumably the wrapped loader selects the same
            # shard that was just quantized - confirm against its definition.
            weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)

        return quantize_and_call_weight_loader
def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
    """Quantize ``weight`` row-wise to the int8 value range, in place.

    Each row is divided by its own scale (``max(|row|) / 127``), rounded,
    and clamped to ``[-127, 127]``. The tensor keeps its original dtype;
    only its values are changed.

    Args:
        weight: 2-D tensor, modified in place.

    Returns:
        Per-row scales with shape ``(rows, 1)``. An all-zero row has a
        scale of 0 and its values are left as zeros.
    """
    vmax = torch.iinfo(torch.int8).max
    scales = torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax

    # An all-zero row has scale 0; dividing by it would turn the row into
    # NaN (0 / 0). Divide such rows by 1 instead so they stay exactly zero,
    # while still reporting the true (zero) scale to the caller.
    safe_scales = torch.where(scales == 0, torch.ones_like(scales), scales)

    weight.div_(safe_scales)
    weight.round_()
    weight.clamp_(-vmax, vmax)

    return scales
vllm/model_executor/layers/quantization/online/base.py
View file @
ad2b1277
...
...
@@ -9,6 +9,7 @@ from vllm.config.quantization import (
OnlineQuantizationConfigArgs
,
OnlineQuantScheme
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
)
...
...
@@ -33,6 +34,11 @@ from vllm.model_executor.layers.quantization.online.fp8 import (
Fp8PerTensorOnlineLinearMethod
,
Fp8PerTensorOnlineMoEMethod
,
)
from
vllm.model_executor.layers.quantization.online.int8
import
(
Int8OnlineMoEMethod
,
)
logger
=
init_logger
(
__name__
)
class
OnlineQuantizationConfig
(
QuantizationConfig
):
...
...
@@ -96,7 +102,13 @@ class OnlineQuantizationConfig(QuantizationConfig):
return
UnquantizedLinearMethod
()
linear_scheme
=
self
.
args
.
linear_scheme_override
or
self
.
args
.
global_scheme
if
linear_scheme
==
OnlineQuantScheme
.
FP8_PER_BLOCK
:
if
linear_scheme
==
OnlineQuantScheme
.
INT8_PER_CHANNEL_WEIGHT_ONLY
:
logger
.
warning_once
(
"INT8 online quantization only quantizes MoE expert "
"weights. linear layers remain in full precision."
)
return
UnquantizedLinearMethod
()
elif
linear_scheme
==
OnlineQuantScheme
.
FP8_PER_BLOCK
:
return
Fp8PerBlockOnlineLinearMethod
()
else
:
return
Fp8PerTensorOnlineLinearMethod
()
...
...
@@ -109,7 +121,9 @@ class OnlineQuantizationConfig(QuantizationConfig):
return
UnquantizedFusedMoEMethod
(
layer
.
moe_config
)
moe_scheme
=
self
.
args
.
moe_scheme_override
or
self
.
args
.
global_scheme
if
moe_scheme
==
OnlineQuantScheme
.
FP8_PER_BLOCK
:
if
moe_scheme
==
OnlineQuantScheme
.
INT8_PER_CHANNEL_WEIGHT_ONLY
:
return
Int8OnlineMoEMethod
(
layer
=
layer
)
elif
moe_scheme
==
OnlineQuantScheme
.
FP8_PER_BLOCK
:
return
Fp8PerBlockOnlineMoEMethod
(
layer
=
layer
)
else
:
return
Fp8PerTensorOnlineMoEMethod
(
layer
=
layer
)
...
...
vllm/model_executor/layers/quantization/online/fp8.py
View file @
ad2b1277
...
...
@@ -10,7 +10,6 @@ if TYPE_CHECKING:
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.oracle.fp8
import
Fp8MoeBackend
...
...
@@ -19,15 +18,15 @@ import vllm.envs as envs
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.kernels.linear
import
init_fp8_linear_kernel
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoEMethodBase
,
)
from
vllm.model_executor.layers.fused_moe.oracle.fp8
import
(
select_fp8_moe_backend
,
)
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
)
from
vllm.model_executor.layers.quantization.online.moe_base
import
(
OnlineMoEMethodBase
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
create_fp8_quant_key
,
...
...
@@ -44,7 +43,7 @@ from vllm.model_executor.model_loader.reload.layerwise import (
initialize_online_processing
,
)
from
vllm.model_executor.parameter
import
ModelWeightParameter
from
vllm.model_executor.utils
import
replace_parameter
,
set_weight_attrs
from
vllm.model_executor.utils
import
replace_parameter
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
per_block_cast_to_fp8
...
...
@@ -268,21 +267,15 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
# ---------------------------------------------------------------------------
class
_Fp8OnlineMoEBase
(
Fused
MoEMethodBase
):
class
_Fp8OnlineMoEBase
(
Online
MoEMethodBase
):
"""Shared base for online FP8 MoE methods. Loads fp16/bf16 checkpoint
weights onto meta device and materializes them just-in-time."""
uses_meta_device
:
bool
=
True
# Declared here for mypy; actual values are set in __init__.
fp8_backend
:
"Fp8MoeBackend"
experts_cls
:
"type[mk.FusedMoEExperts] | None"
weight_scale_name
:
str
weight_block_size
:
list
[
int
]
|
None
moe
:
"FusedMoEConfig"
is_monolithic
:
bool
moe_quant_config
:
"FusedMoEQuantConfig | None"
moe_kernel
:
"mk.FusedMoEKernel | None"
def
__init__
(
self
,
...
...
@@ -313,77 +306,6 @@ class _Fp8OnlineMoEBase(FusedMoEMethodBase):
allow_vllm_cutlass
=
False
,
)
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
device
=
"meta"
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
device
=
"meta"
,
# materialized and processed during loading
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# BIASES (for models like GPT-OSS that have biased MoE)
if
self
.
moe
.
has_bias
:
w13_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
,
device
=
"meta"
,
# materialized and processed during loading
dtype
=
layer
.
orig_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_bias"
,
w13_bias
)
set_weight_attrs
(
w13_bias
,
extra_weight_attrs
)
w2_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
hidden_size
,
device
=
"meta"
,
# materialized and processed during loading
dtype
=
layer
.
orig_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_bias"
,
w2_bias
)
set_weight_attrs
(
w2_bias
,
extra_weight_attrs
)
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
initialize_online_processing
(
layer
)
def
_setup_kernel
(
self
,
layer
:
"FusedMoE"
,
...
...
@@ -430,15 +352,6 @@ class _Fp8OnlineMoEBase(FusedMoEMethodBase):
shared_experts
=
layer
.
shared_experts
,
)
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
"mk.FusedMoEPrepareAndFinalizeModular | None"
:
raise
ValueError
(
f
"
{
self
.
__class__
.
__name__
}
uses the new modular kernel "
"initialization logic. This function should not be called."
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
"FusedMoEQuantConfig"
:
...
...
@@ -460,68 +373,9 @@ class _Fp8OnlineMoEBase(FusedMoEMethodBase):
block_shape
=
self
.
weight_block_size
,
)
# Inject biases into the quant config if the model has them
# (e.g. GPT-OSS biased MoE)
if
quant_config
is
not
None
and
self
.
moe
.
has_bias
:
w13_bias
=
getattr
(
layer
,
"w13_bias"
,
None
)
w2_bias
=
getattr
(
layer
,
"w2_bias"
,
None
)
if
w13_bias
is
not
None
:
quant_config
.
_w1
.
bias
=
w13_bias
if
w2_bias
is
not
None
:
quant_config
.
_w2
.
bias
=
w2_bias
self
.
_maybe_inject_biases
(
quant_config
,
layer
)
return
quant_config
@
property
def
supports_eplb
(
self
)
->
bool
:
return
True
def
apply_monolithic
(
self
,
layer
:
"FusedMoE"
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply_monolithic
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
activation
=
layer
.
activation
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
num_expert_group
=
layer
.
num_expert_group
,
topk_group
=
layer
.
topk_group
,
e_score_correction_bias
=
layer
.
e_score_correction_bias
,
routed_scaling_factor
=
layer
.
routed_scaling_factor
,
)
def
apply
(
self
,
layer
:
"FusedMoE"
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
not
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
activation
=
layer
.
activation
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
shared_experts_input
=
shared_experts_input
,
)
class
Fp8PerTensorOnlineMoEMethod
(
_Fp8OnlineMoEBase
):
"""Online tensorwise FP8 MoE quantization.
...
...
vllm/model_executor/layers/quantization/online/int8.py
0 → 100644
View file @
ad2b1277
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
TYPE_CHECKING
import
torch
from
torch.nn
import
Module
if
TYPE_CHECKING
:
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.oracle.int8
import
(
make_int8_moe_kernel
,
make_int8_moe_quant_config
,
select_int8_moe_backend
,
)
from
vllm.model_executor.layers.quantization.online.moe_base
import
(
OnlineMoEMethodBase
,
)
from
vllm.model_executor.utils
import
replace_parameter
class Int8OnlineMoEMethod(OnlineMoEMethodBase):
    """Online per-channel INT8 MoE quantization.

    Loads fp16/bf16 weights and quantizes them per-row to int8 during loading.
    """

    def __init__(
        self,
        *,
        layer: torch.nn.Module,
    ):
        """Initialize from ``layer.moe_config`` and select the experts backend.

        Raises whatever ``select_int8_moe_backend`` raises when the
        configuration is not supported.
        """
        super().__init__(layer.moe_config)
        self.experts_cls: type[mk.FusedMoEExperts] = select_int8_moe_backend(
            config=self.moe,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded weights and build the kernel, at most once.

        A flag stored on ``layer`` makes repeated calls a no-op.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        self._quantize_weights(layer)
        self._setup_kernel(layer)

        layer._already_called_process_weights_after_loading = True

    @staticmethod
    def _quantize_rows_to_int8(
        w: torch.Tensor, vmax: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize a 2-D weight per row; return ``(int8 weight, row scales)``."""
        scales = w.abs().amax(dim=1) / vmax
        # An all-zero row has scale 0; dividing by it gives NaN (0 / 0), and
        # casting NaN to int8 is not well defined. Divide such rows by 1 so
        # they quantize to exact zeros; the reported scale stays 0.
        safe_scales = torch.where(scales == 0, torch.ones_like(scales), scales)
        q = w.div(safe_scales.unsqueeze(1)).round().clamp(-vmax, vmax)
        return q.to(torch.int8), scales

    def _quantize_weights(self, layer: Module) -> None:
        """Replace the layer's expert weights with int8 weights plus scales.

        Quantization is done one expert at a time. The results are installed
        on ``layer`` as ``w13_weight``, ``w2_weight``, ``w13_scale`` and
        ``w2_scale`` via ``replace_parameter``.
        """
        vmax = torch.iinfo(torch.int8).max

        w13 = torch.empty_like(layer.w13_weight, dtype=torch.int8)
        w2 = torch.empty_like(layer.w2_weight, dtype=torch.int8)
        w13_scale = torch.zeros(
            layer.num_experts,
            layer.w13_weight.shape[1],
            device=w13.device,
            dtype=torch.float32,
        )
        w2_scale = torch.zeros(
            layer.num_experts,
            layer.w2_weight.shape[1],
            device=w2.device,
            dtype=torch.float32,
        )

        # NOTE(review): buffers are sized by layer.num_experts but the loop
        # runs over layer.local_num_experts - presumably these are equal
        # here; confirm against the caller of create_weights.
        for expert in range(layer.local_num_experts):
            # w13: per-row quantization over hidden_size dim
            q, scales = self._quantize_rows_to_int8(
                layer.w13_weight[expert, :, :], vmax
            )
            w13[expert, :, :] = q
            w13_scale[expert, :] = scales

            # w2: per-row quantization over intermediate_size dim
            q, scales = self._quantize_rows_to_int8(
                layer.w2_weight[expert, :, :], vmax
            )
            w2[expert, :, :] = q
            w2_scale[expert, :] = scales

        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w13_scale", w13_scale)
        replace_parameter(layer, "w2_scale", w2_scale)

    def _setup_kernel(self, layer: "FusedMoE") -> None:
        """Build the quant config and the fused-MoE kernel for ``layer``."""
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        assert self.moe_quant_config is not None
        assert self.experts_cls is not None

        self.moe_kernel = make_int8_moe_kernel(
            moe_quant_config=self.moe_quant_config,
            moe_config=self.moe,
            experts_cls=self.experts_cls,
            routing_tables=layer._maybe_init_expert_routing_tables(),
            shared_experts=layer.shared_experts,
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> "FusedMoEQuantConfig | None":
        """Return the int8 quant config for ``layer``, with biases if present."""
        quant_config = make_int8_moe_quant_config(
            w1_scale=layer.w13_scale,
            w2_scale=layer.w2_scale,
        )
        self._maybe_inject_biases(quant_config, layer)
        return quant_config
vllm/model_executor/layers/quantization/online/moe_base.py
0 → 100644
View file @
ad2b1277
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
abstractmethod
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe
import
FusedMoEMethodBase
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.model_loader.reload.layerwise
import
(
initialize_online_processing
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
class OnlineMoEMethodBase(FusedMoEMethodBase):
    """Common base for MoE methods that quantize weights after loading.

    Full-precision expert weights are declared on the meta device and are
    quantized once loaded, via the QeRL layerwise processing system.
    Subclasses supply ``process_weights_after_loading``.
    """

    uses_meta_device: bool = True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Declare the expert weights (and biases, if any) on the meta device.

        Also records ``num_experts``, ``orig_dtype`` and a ``None``
        ``weight_block_size`` on ``layer``, clears the two input-scale
        attributes, and finally calls ``initialize_online_processing``.
        """
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        def _declare(name, make_tensor, shape, dtype):
            # Create a frozen meta-device parameter, attach it to the layer
            # and tag it with the caller-supplied weight attributes.
            param = torch.nn.Parameter(
                make_tensor(*shape, device="meta", dtype=dtype),
                requires_grad=False,
            )
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Fused gate_up_proj (column parallel) - full precision.
        _declare(
            "w13_weight",
            torch.empty,
            (num_experts, 2 * intermediate_size_per_partition, hidden_size),
            params_dtype,
        )

        # down_proj (row parallel) - full precision.
        _declare(
            "w2_weight",
            torch.empty,
            (num_experts, hidden_size, intermediate_size_per_partition),
            params_dtype,
        )

        # Biases exist only for models with biased MoE (e.g. GPT-OSS).
        if self.moe.has_bias:
            _declare(
                "w13_bias",
                torch.zeros,
                (num_experts, 2 * intermediate_size_per_partition),
                layer.orig_dtype,
            )
            _declare(
                "w2_bias",
                torch.zeros,
                (num_experts, hidden_size),
                layer.orig_dtype,
            )

        layer.w13_input_scale = None
        layer.w2_input_scale = None

        initialize_online_processing(layer)

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Quantize the loaded weights of ``layer``; defined by subclasses."""
        pass

    def _maybe_inject_biases(
        self,
        quant_config: FusedMoEQuantConfig,
        layer: torch.nn.Module,
    ) -> None:
        """Copy the layer's MoE biases into ``quant_config`` when the model
        has them (e.g. GPT-OSS biased MoE). Does nothing otherwise."""
        if not self.moe.has_bias:
            return

        bias_13 = getattr(layer, "w13_bias", None)
        bias_2 = getattr(layer, "w2_bias", None)
        if bias_13 is not None:
            quant_config._w1.bias = bias_13
        if bias_2 is not None:
            quant_config._w2.bias = bias_2

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
        """Always raise: this class does not use this entry point."""
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel "
            "initialization logic. This function should not be called."
        )

    @property
    def supports_eplb(self) -> bool:
        """Always ``True``."""
        return True

    def apply_monolithic(
        self,
        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the monolithic kernel path, driven by router logits.

        Requires ``self.is_monolithic`` to be true and the kernel to have
        been built already.
        """
        assert self.is_monolithic
        assert self.moe_kernel is not None
        kernel = self.moe_kernel
        return kernel.apply_monolithic(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
        self,
        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the modular kernel path with precomputed top-k routing.

        Requires ``self.is_monolithic`` to be false and the kernel to have
        been built already.
        """
        assert not self.is_monolithic
        assert self.moe_kernel is not None
        kernel = self.moe_kernel
        return kernel.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment