Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fd0e3772
Unverified
Commit
fd0e3772
authored
Jan 30, 2026
by
Michael Goin
Committed by
GitHub
Jan 30, 2026
Browse files
Support FP8 block quant for CompressedTensorsW8A16Fp8 (#33280)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
f857a03f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
74 additions
and
64 deletions
+74
-64
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+2
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
...ompressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
+70
-59
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+0
-2
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+2
-1
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
fd0e3772
...
@@ -651,7 +651,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -651,7 +651,7 @@ class CompressedTensorsConfig(QuantizationConfig):
# note: input_quant will be present for converted models;
# note: input_quant will be present for converted models;
# will be ignored during inference post loading
# will be ignored during inference post loading
return
CompressedTensorsW8A16Fp8
(
return
CompressedTensorsW8A16Fp8
(
strategy
=
weight_quant
.
strategy
,
weight_quant
=
weight_quant
,
is_static_input_scheme
=
not
input_quant
.
dynamic
,
is_static_input_scheme
=
not
input_quant
.
dynamic
,
)
)
...
@@ -659,7 +659,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -659,7 +659,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if
self
.
_is_fp8_w8a16
(
weight_quant
,
input_quant
):
if
self
.
_is_fp8_w8a16
(
weight_quant
,
input_quant
):
is_static_input_scheme
=
input_quant
and
not
input_quant
.
dynamic
is_static_input_scheme
=
input_quant
and
not
input_quant
.
dynamic
return
CompressedTensorsW8A16Fp8
(
return
CompressedTensorsW8A16Fp8
(
strategy
=
weight_quant
.
strategy
,
weight_quant
=
weight_quant
,
is_static_input_scheme
=
is_static_input_scheme
,
is_static_input_scheme
=
is_static_input_scheme
,
)
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
View file @
fd0e3772
...
@@ -4,11 +4,17 @@
...
@@ -4,11 +4,17 @@
from
collections.abc
import
Callable
from
collections.abc
import
Callable
import
torch
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
compressed_tensors.quantization
import
QuantizationArgs
,
QuantizationStrategy
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
,
CompressedTensorsScheme
,
)
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
create_fp8_scale_parameter
,
create_fp8_weight_parameter
,
process_fp8_weight_block_strategy
,
validate_fp8_block_shape
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
,
prepare_fp8_layer_for_marlin
,
...
@@ -17,57 +23,40 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -17,57 +23,40 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise
,
convert_to_channelwise
,
)
)
from
vllm.model_executor.parameter
import
(
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
ChannelQuantScaleParameter
,
ChannelQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
,
PerTensorScaleParameter
,
)
)
from
vllm.model_executor.utils
import
replace_parameter
__all__
=
[
"CompressedTensorsW8A16Fp8"
]
__all__
=
[
"CompressedTensorsW8A16Fp8"
]
SUPPORTED_STRATEGIES
=
[
QuantizationStrategy
.
CHANNEL
,
QuantizationStrategy
.
TENSOR
]
strategy_to_parameter_type
=
{
QuantizationStrategy
.
BLOCK
:
BlockQuantScaleParameter
,
QuantizationStrategy
.
CHANNEL
:
ChannelQuantScaleParameter
,
QuantizationStrategy
.
TENSOR
:
PerTensorScaleParameter
,
}
class
CompressedTensorsW8A16Fp8
(
CompressedTensorsScheme
):
class
CompressedTensorsW8A16Fp8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
):
def
__init__
(
self
,
weight_quant
:
QuantizationArgs
,
is_static_input_scheme
:
bool
):
self
.
strategy
=
strategy
self
.
weight_quant
=
weight_quant
self
.
strategy
=
weight_quant
.
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
weight_block_size
=
self
.
weight_quant
.
block_structure
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
# turing and up
# turing and up
return
75
return
75
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales,
# we expand each scale to its shard's channels.
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
if
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
ws_channelwise
=
convert_to_channelwise
(
layer
.
weight_scale
,
layer
.
logical_widths
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
ws_channelwise
,
requires_grad
=
False
)
else
:
# required by torch.compile to be torch.nn.Parameter
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
# Weights must be transposed for marlin
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
t
(),
requires_grad
=
False
)
if
self
.
is_static_input_scheme
:
# required by torch.compile to be torch.nn.Parameter
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
prepare_fp8_layer_for_marlin
(
layer
)
def
create_weights
(
def
create_weights
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_partition_sizes
:
list
[
int
],
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
weight_loader
:
Callable
,
**
kwargs
,
**
kwargs
,
...
@@ -79,38 +68,33 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
...
@@ -79,38 +68,33 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
layer
.
orig_dtype
=
params_dtype
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
layer
.
weight_block_size
=
None
# WEIGHT
if
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
weight
=
ModelWeightParameter
(
assert
self
.
weight_block_size
is
not
None
data
=
torch
.
empty
(
layer
.
weight_block_size
=
self
.
weight_block_size
output_size_per_partition
,
# Validate block quantization shapes
validate_fp8_block_shape
(
layer
,
input_size
,
output_size
,
input_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
,
output_partition_sizes
,
),
self
.
weight_block_size
,
input_dim
=
1
,
)
output_dim
=
0
,
weight_loader
=
weight_loader
,
# WEIGHT
weight
=
create_fp8_weight_parameter
(
output_size_per_partition
,
input_size_per_partition
,
weight_loader
)
)
layer
.
register_parameter
(
"weight"
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
# WEIGHT SCALE
# WEIGHT SCALE
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight_scale
=
create_fp8_scale_parameter
(
weight_scale
=
ChannelQuantScaleParameter
(
strategy_to_parameter_type
[
self
.
strategy
],
data
=
torch
.
empty
((
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_partition_sizes
,
output_dim
=
0
,
input_size_per_partition
,
weight_loader
=
weight_loader
,
layer
.
weight_block_size
,
)
weight_loader
,
elif
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
)
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
else
:
raise
ValueError
(
f
"Unsupported weight strategy=
{
self
.
strategy
}
, "
f
"supported strategies are
{
SUPPORTED_STRATEGIES
}
"
)
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE (to deal with converted checkpoints)
# INPUT SCALE (to deal with converted checkpoints)
...
@@ -121,6 +105,33 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
...
@@ -121,6 +105,33 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
)
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
size_k_first
=
True
# TODO(rob): refactor block quant into separate class.
if
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
assert
self
.
is_static_input_scheme
is
False
size_k_first
=
False
weight
,
weight_scale
=
process_fp8_weight_block_strategy
(
weight
,
weight_scale
)
else
:
# Weights must be transposed for marlin
weight
=
weight
.
t
()
if
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
# If we have a fused module (QKV, MLP) with per tensor scales,
# we expand each scale to its shard's channels.
weight_scale
=
convert_to_channelwise
(
weight_scale
,
layer
.
logical_widths
)
# Update layer with new values
replace_parameter
(
layer
,
"weight"
,
weight
.
data
)
replace_parameter
(
layer
,
"weight_scale"
,
weight_scale
.
data
)
prepare_fp8_layer_for_marlin
(
layer
,
size_k_first
=
size_k_first
)
def
apply_weights
(
def
apply_weights
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
fd0e3772
...
@@ -400,7 +400,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -400,7 +400,6 @@ class Fp8LinearMethod(LinearMethodBase):
None
,
None
,
weight_loader
,
weight_loader
,
)
)
set_weight_attrs
(
scale
,
{
"scale_type"
:
"weight_scale"
})
layer
.
register_parameter
(
"weight_scale"
,
scale
)
layer
.
register_parameter
(
"weight_scale"
,
scale
)
else
:
else
:
assert
not
self
.
act_q_static
assert
not
self
.
act_q_static
...
@@ -412,7 +411,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -412,7 +411,6 @@ class Fp8LinearMethod(LinearMethodBase):
self
.
weight_block_size
,
self
.
weight_block_size
,
weight_loader
,
weight_loader
,
)
)
set_weight_attrs
(
scale
,
{
"scale_type"
:
"weight_scale"
})
# The weight_scale_inv name is intentional for deepseekv3
# The weight_scale_inv name is intentional for deepseekv3
layer
.
register_parameter
(
"weight_scale_inv"
,
scale
)
layer
.
register_parameter
(
"weight_scale_inv"
,
scale
)
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
fd0e3772
...
@@ -29,7 +29,7 @@ from vllm.model_executor.parameter import (
...
@@ -29,7 +29,7 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter
,
ChannelQuantScaleParameter
,
PerTensorScaleParameter
,
PerTensorScaleParameter
,
)
)
from
vllm.model_executor.utils
import
replace_parameter
from
vllm.model_executor.utils
import
replace_parameter
,
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.deep_gemm
import
(
from
vllm.utils.deep_gemm
import
(
...
@@ -1520,6 +1520,7 @@ def create_fp8_scale_parameter(
...
@@ -1520,6 +1520,7 @@ def create_fp8_scale_parameter(
raise
ValueError
(
f
"Unknown parameter type:
{
parameter_type
}
"
)
raise
ValueError
(
f
"Unknown parameter type:
{
parameter_type
}
"
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
set_weight_attrs
(
scale
,
{
"scale_type"
:
"weight_scale"
})
return
scale
return
scale
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment