Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fbd6523a
Unverified
Commit
fbd6523a
authored
Sep 18, 2025
by
Michael Goin
Committed by
GitHub
Sep 18, 2025
Browse files
Refactor dense FP8 tensor/channel/block utils and add CT FP8 block (#21404)
parent
470484a4
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
442 additions
and
318 deletions
+442
-318
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+7
-7
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+35
-33
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+96
-95
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+84
-183
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+220
-0
No files found.
vllm/model_executor/layers/linear.py
View file @
fbd6523a
...
...
@@ -805,12 +805,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
if
isinstance
(
param
,
BlockQuantScaleParameter
):
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8LinearMethod
,
Fp8MoEMethod
)
assert
self
.
quant_method
is
not
None
assert
isinstance
(
self
.
quant
_
method
,
(
Fp8LinearMethod
,
Fp8MoEMethod
)
)
weight_block_size
=
self
.
quant_method
.
quant_config
.
weight_block_size
# Assume the weight block size has been set by
quant
method
assert
hasattr
(
self
,
"weight_block_size"
)
weight_block_size
=
self
.
weight_block_size
assert
weight_block_size
is
not
None
block_n
,
_
=
weight_block_size
[
0
],
weight_block_size
[
1
]
shard_offset
=
(
...
...
@@ -989,8 +987,10 @@ class QKVParallelLinear(ColumnParallelLinear):
# Note(simon): This is needed for Qwen3's fp8 quantization.
if
isinstance
(
param
,
BlockQuantScaleParameter
):
assert
self
.
quant_method
is
not
None
assert
hasattr
(
self
.
quant_method
,
"quant_config"
)
weight_block_size
=
self
.
quant_method
.
quant_config
.
weight_block_size
# Assume the weight block size has been set by quant method
assert
hasattr
(
self
,
"weight_block_size"
)
weight_block_size
=
self
.
weight_block_size
assert
weight_block_size
is
not
None
block_n
,
_
=
weight_block_size
[
0
],
weight_block_size
[
1
]
shard_offset
=
(
shard_offset
+
block_n
-
1
)
//
block_n
shard_size
=
(
shard_size
+
block_n
-
1
)
//
block_n
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
fbd6523a
...
...
@@ -12,7 +12,6 @@ from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy
,
QuantizationType
)
from
compressed_tensors.transform
import
TransformConfig
from
pydantic
import
BaseModel
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
...
...
@@ -268,7 +267,8 @@ class CompressedTensorsConfig(QuantizationConfig):
else
:
return
False
def
_is_fp4a4_nvfp4
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
):
def
_is_fp4a4_nvfp4
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
):
if
weight_quant
is
None
or
input_quant
is
None
:
return
False
...
...
@@ -288,8 +288,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return
(
is_tensor_group_quant
and
is_float_type
and
is_4_bits
and
is_group_size_16
and
is_symmetric
)
def
_is_fp4a16_nvfp4
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
):
def
_is_fp4a16_nvfp4
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
):
is_weight_only
=
weight_quant
is
not
None
and
input_quant
is
None
is_tensor_group_quant
=
(
...
...
@@ -303,8 +303,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return
(
is_weight_only
and
is_tensor_group_quant
and
is_float_type
and
is_4_bits
and
is_group_size_16
and
is_symmetric
)
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
is_8_bits
=
weight_quant
.
num_bits
==
input_quant
.
num_bits
==
8
weight_strategy
=
(
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
...
...
@@ -317,8 +317,8 @@ class CompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return
is_8_bits
and
is_tensor
and
weight_quant
.
symmetric
and
is_static
def
_is_dynamic_token_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_dynamic_token_w8a8
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
is_8_bits
=
weight_quant
.
num_bits
==
input_quant
.
num_bits
==
8
weight_strategy
=
(
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
...
...
@@ -331,8 +331,8 @@ class CompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return
is_8_bits
and
is_token
and
weight_quant
.
symmetric
and
is_dynamic
def
_is_dynamic_token_w4a8_int
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_dynamic_token_w4a8_int
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
is_weight_4_bits
=
weight_quant
.
num_bits
==
4
is_activation_8_bits
=
input_quant
.
num_bits
==
8
weight_strategy
=
(
...
...
@@ -347,8 +347,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return
(
is_weight_4_bits
and
is_activation_8_bits
and
is_token
and
weight_quant
.
symmetric
and
is_dynamic
)
def
_is_fp8_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_fp8_w8a8
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
# Confirm weights and activations quantized.
if
weight_quant
is
None
or
input_quant
is
None
:
return
False
...
...
@@ -358,11 +358,12 @@ class CompressedTensorsConfig(QuantizationConfig):
and
input_quant
.
type
==
QuantizationType
.
FLOAT
)
is_symmetric_weight
=
weight_quant
.
symmetric
is_static_weight
=
not
weight_quant
.
dynamic
is_per_tensor_or_channel_weight
=
(
weight_quant
.
strategy
in
[
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
is_tensor_or_channel_or_block_weight
=
(
weight_quant
.
strategy
in
[
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
,
QuantizationStrategy
.
BLOCK
])
if
not
(
is_floating_point
and
is_symmetric_weight
and
is_static_weight
and
is_
per_
tensor_or_channel_weight
):
and
is_tensor_or_channel_
or_block_
weight
):
return
False
# Dynamic quantization is always supported if weights supported.
...
...
@@ -375,8 +376,8 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
return
is_symmetric_activation
and
is_per_tensor_activation
def
_is_fp8_w4a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_fp8_w4a8
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
if
not
weight_quant
or
not
input_quant
:
return
False
is_weight_4_bits
=
weight_quant
.
num_bits
==
4
...
...
@@ -392,24 +393,24 @@ class CompressedTensorsConfig(QuantizationConfig):
return
(
is_weight_4_bits
and
is_activation_8_bits
and
is_token
and
is_symmetric
and
is_dynamic
)
def
_is_fp8_w4a8_sm90
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_fp8_w4a8_sm90
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
return
(
self
.
_check_scheme_supported
(
90
,
error
=
False
,
match_exact
=
True
)
and
self
.
_is_fp8_w4a8
(
weight_quant
,
input_quant
))
def
_is_fp8_w8a8_sm90
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_fp8_w8a8_sm90
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
return
(
self
.
_check_scheme_supported
(
90
,
error
=
False
,
match_exact
=
True
)
and
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
))
def
_is_fp8_w8a8_sm100
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_fp8_w8a8_sm100
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
return
(
self
.
_check_scheme_supported
(
100
,
error
=
False
,
match_exact
=
True
)
and
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
))
def
_is_fp8_w8a16
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_fp8_w8a16
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
# Confirm weights quantized.
if
weight_quant
is
None
:
return
False
...
...
@@ -421,18 +422,19 @@ class CompressedTensorsConfig(QuantizationConfig):
# Confirm weight scheme is supported.
is_symmetric_weight
=
weight_quant
.
symmetric
is_static_weight
=
not
weight_quant
.
dynamic
is_per_tensor_or_channel_weight
=
(
weight_quant
.
strategy
in
[
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
is_tensor_or_channel_or_block_weight
=
(
weight_quant
.
strategy
in
[
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
,
QuantizationStrategy
.
BLOCK
])
if
not
(
is_symmetric_weight
and
is_static_weight
# noqa: SIM103
and
is_
per_
tensor_or_channel_weight
):
and
is_tensor_or_channel_
or_block_
weight
):
return
False
# All conditions satisfied.
return
True
def
_is_wNa16_group_channel
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_wNa16_group_channel
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
input_quant_none
=
input_quant
is
None
is_channel_group
=
(
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
...
...
@@ -443,8 +445,8 @@ class CompressedTensorsConfig(QuantizationConfig):
def
_get_scheme_from_parts
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
,
format
:
Optional
[
str
]
=
None
)
->
"CompressedTensorsScheme"
:
# use the per-layer format if defined, otherwise, use global format
...
...
@@ -496,7 +498,7 @@ class CompressedTensorsConfig(QuantizationConfig):
CompressedTensorsW8A8Fp8
.
get_min_capability
(),
error
=
False
)
if
is_fp8_w8a8_supported
:
return
CompressedTensorsW8A8Fp8
(
strategy
=
weight_quant
.
strategy
,
weight_quant
=
weight_quant
,
is_static_input_scheme
=
(
input_quant
and
not
input_quant
.
dynamic
))
else
:
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
fbd6523a
...
...
@@ -4,28 +4,41 @@
from
typing
import
Callable
,
Optional
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
)
from
torch.nn
import
Parameter
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_fp8_block_linear
,
check_aiter_fp8_linear_support
,
create_fp8_input_scale
,
create_fp8_scale_parameter
,
create_fp8_weight_parameter
,
maybe_post_process_fp8_weight_block
,
process_fp8_weight_block_strategy
,
process_fp8_weight_channel_strategy
,
process_fp8_weight_tensor_strategy
,
validate_fp8_block_shape
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
,
Fp8LinearOp
,
cutlass_block_fp8_supported
,
maybe_create_device_identity
)
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
ChannelQuantScaleParameter
,
PerTensorScaleParameter
)
from
vllm.platforms
import
current_platform
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
strategy_to_parameter_type
=
{
QuantizationStrategy
.
BLOCK
:
BlockQuantScaleParameter
,
QuantizationStrategy
.
CHANNEL
:
ChannelQuantScaleParameter
,
QuantizationStrategy
.
TENSOR
:
PerTensorScaleParameter
,
}
class
CompressedTensorsW8A8Fp8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
):
self
.
strategy
=
strategy
def
__init__
(
self
,
weight_quant
:
QuantizationArgs
,
is_static_input_scheme
:
bool
):
self
.
weight_quant
=
weight_quant
self
.
strategy
=
weight_quant
.
strategy
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
act_q_group_shape
=
GroupShape
.
PER_TENSOR
\
...
...
@@ -34,120 +47,108 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
act_quant_static
=
self
.
is_static_input_scheme
,
act_quant_group_shape
=
self
.
act_q_group_shape
)
self
.
weight_block_size
=
self
.
weight_quant
.
block_structure
self
.
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
()
self
.
use_aiter_and_is_supported
=
check_aiter_fp8_linear_support
()
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# lovelace and up
return
89
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor
if
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
max_w_scale
,
weight
=
requantize_with_max_scale
(
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
)
if
current_platform
.
is_fp8_fnuz
():
input_scale
=
getattr
(
layer
,
'input_scale'
,
None
)
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
max_w_scale
,
input_scale
=
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
# If channelwise, scales are already lined up, so just transpose.
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
=
layer
.
weight
if
current_platform
.
is_fp8_fnuz
():
input_scale
=
getattr
(
layer
,
'input_scale'
,
None
)
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
else
:
weight_scale
=
layer
.
weight_scale
.
data
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
# required by torch.compile to be torch.nn.Parameter
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
else
:
raise
ValueError
(
f
"Unknown quantization strategy
{
self
.
strategy
}
"
)
# INPUT SCALE
if
self
.
is_static_input_scheme
and
hasattr
(
layer
,
'input_scale'
):
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
else
:
layer
.
input_scale
=
None
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
list
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
maybe_create_device_identity
()
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
weight_block_size
=
None
if
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
assert
self
.
weight_block_size
is
not
None
layer
.
weight_block_size
=
self
.
weight_block_size
# Validate block quantization shapes
validate_fp8_block_shape
(
layer
,
input_size
,
output_size
,
input_size_per_partition
,
output_partition_sizes
,
self
.
weight_block_size
)
# WEIGHT
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
weight
=
create_fp8_weight_parameter
(
output_size_per_partition
,
input_size_per_partition
,
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
# WEIGHT SCALE
# TODO: update create_xxx_parameter functions to return
# the newly added parameters
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
((
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
else
:
assert
self
.
strategy
==
QuantizationStrategy
.
TENSOR
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
# min requirement for fp8 kernels
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
weight_scale
=
create_fp8_scale_parameter
(
strategy_to_parameter_type
[
self
.
strategy
],
output_partition_sizes
,
input_size_per_partition
,
layer
.
weight_block_size
,
weight_loader
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE
if
self
.
is_static_input_scheme
:
input_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
input_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
input_scale
=
create_fp8_input_scale
(
output_partition_sizes
,
weight_loader
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
if
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
weight
,
weight_scale
,
input_scale
=
(
process_fp8_weight_tensor_strategy
(
layer
.
weight
,
layer
.
weight_scale
,
layer
.
logical_widths
,
getattr
(
layer
,
'input_scale'
,
None
)))
weight
=
weight
.
t
()
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
,
weight_scale
,
input_scale
=
(
process_fp8_weight_channel_strategy
(
layer
.
weight
,
layer
.
weight_scale
,
getattr
(
layer
,
'input_scale'
,
None
)))
weight
=
weight
.
t
()
elif
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
assert
self
.
is_static_input_scheme
is
False
weight
,
weight_scale
=
process_fp8_weight_block_strategy
(
layer
.
weight
,
layer
.
weight_scale
)
input_scale
=
None
else
:
raise
ValueError
(
f
"Unknown quantization strategy
{
self
.
strategy
}
"
)
# required by torch.compile to be torch.nn.Parameter
layer
.
weight
=
Parameter
(
weight
.
data
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
.
data
,
requires_grad
=
False
)
# INPUT SCALE
if
self
.
is_static_input_scheme
and
hasattr
(
layer
,
'input_scale'
):
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
else
:
layer
.
input_scale
=
None
if
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
maybe_post_process_fp8_weight_block
(
layer
,
self
.
cutlass_block_fp8_supported
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
layer
.
weight_block_size
is
not
None
:
return
apply_fp8_block_linear
(
layer
,
input
=
x
,
bias
=
bias
,
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
)
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
fbd6523a
This diff is collapsed.
Click to expand it.
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
fbd6523a
...
...
@@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_BLOCK_FP8_SUPPORTED
)
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
ChannelQuantScaleParameter
,
PerTensorScaleParameter
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
,
direct_register_custom_op
...
...
@@ -794,3 +797,220 @@ def requant_weight_ue8m0_inplace(
# Write back the results in-place.
w_q
.
copy_
(
w_requant
)
s_old
.
copy_
(
s_requant
)
def check_aiter_fp8_linear_support() -> bool:
    """Report whether the AITER FP8 linear path may be used.

    Every condition below must hold. They are evaluated in this order and
    evaluation stops at the first falsy one:

    * the current platform is ROCm,
    * ``envs.VLLM_ROCM_USE_AITER`` is truthy,
    * ``envs.VLLM_ROCM_USE_AITER_LINEAR`` is truthy,
    * the current platform uses the FP8 FNUZ format.

    According to the original author's note, this currently corresponds to
    the MI300 series.
    """
    supported = current_platform.is_rocm()
    supported = supported and envs.VLLM_ROCM_USE_AITER
    supported = supported and envs.VLLM_ROCM_USE_AITER_LINEAR
    supported = supported and current_platform.is_fp8_fnuz()
    return supported
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
    """Optionally re-lay-out ``weight`` with padded rows (ROCm optimization).

    The original author's note: ROCm can benefit from tensors located far
    enough from one another in memory.

    Padding is applied only when all of the following hold:

    * ``envs.VLLM_ROCM_FP8_PADDING`` is truthy,
    * the current platform is ROCm,
    * the last dimension has stride 1,
    * the second-to-last stride, measured in bytes, is a multiple of 512.

    Args:
        weight: Tensor to (possibly) pad.

    Returns:
        ``weight`` unchanged when padding does not apply. Otherwise the
        tensor is padded by 256 bytes' worth of zero elements along the last
        dimension and then sliced back, so the result keeps the original
        shape.
    """
    should_pad = (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                  and weight.stride(-1) == 1
                  and (weight.stride(-2) * weight.element_size()) % 512 == 0)
    if not should_pad:
        return weight

    # 256 bytes of padding, expressed as a number of elements.
    num_pad = 256 // weight.element_size()
    import torch.nn.functional as F

    # Pad the last dimension, then drop the padded columns again.
    weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
    torch.cuda.empty_cache()
    return weight
def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int,
                             output_size: int, input_size_per_partition: int,
                             output_partition_sizes: list[int],
                             block_size: list[int]) -> None:
    """Validate block quantization shapes for tensor parallelism.

    Args:
        layer: Module being set up. Its ``tp_size`` attribute is used when it
            is present and not ``None``; otherwise the tensor-parallel world
            size is queried from ``vllm.distributed``.
        input_size: Total input size; compared with
            ``input_size_per_partition`` to detect a row-parallel split.
        output_size: Total output size; compared with the sum of
            ``output_partition_sizes`` to detect a column-parallel split.
        input_size_per_partition: Input size held by this partition.
        output_partition_sizes: Output sizes of the (possibly merged)
            sub-matrices held by this partition.
        block_size: ``[block_n, block_k]`` quantization block shape.

    Raises:
        ValueError: If a row-parallel input partition is not divisible by
            ``block_k``, or a checked output partition is not divisible by
            ``block_n``.
    """
    # Resolve the TP size lazily. Passing the world-size call as the
    # ``getattr`` default would evaluate it (and require initialized
    # distributed state) even when the layer already carries ``tp_size``.
    tp_size = getattr(layer, "tp_size", None)
    if tp_size is None:
        from vllm.distributed import get_tensor_model_parallel_world_size
        tp_size = get_tensor_model_parallel_world_size()
    block_n, block_k = block_size[0], block_size[1]

    # Required by row parallel
    if (tp_size > 1 and input_size // input_size_per_partition == tp_size
            and input_size_per_partition % block_k != 0):
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition} "
            f"is not divisible by weight quantization block_k = {block_k}.")

    # Required by column parallel or enabling merged weights
    is_tp_split = (tp_size > 1
                   and output_size // sum(output_partition_sizes) == tp_size)
    is_merged_gemm = len(output_partition_sizes) > 1
    if is_tp_split or is_merged_gemm:
        sizes_to_check = output_partition_sizes
        if not is_tp_split and is_merged_gemm:
            # In case of merged matrices, we allow the last
            # matrix to not be a multiple of block size
            sizes_to_check = output_partition_sizes[:-1]
        for output_partition_size in sizes_to_check:
            if output_partition_size % block_n != 0:
                raise ValueError(
                    f"Weight output_partition_size = "
                    f"{output_partition_size} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
def create_fp8_weight_parameter(
        output_size_per_partition: int, input_size_per_partition: int,
        weight_loader: Optional[Callable]) -> torch.nn.Parameter:
    """Build the FP8 weight parameter for one partition.

    Args:
        output_size_per_partition: Number of rows of the weight.
        input_size_per_partition: Number of columns of the weight.
        weight_loader: Loader callback forwarded to the parameter.

    Returns:
        A ``ModelWeightParameter`` wrapping an uninitialized
        ``(output_size_per_partition, input_size_per_partition)`` tensor of
        dtype ``torch.float8_e4m3fn``, with ``input_dim=1`` and
        ``output_dim=0``.
    """
    from vllm.model_executor.parameter import ModelWeightParameter

    storage = torch.empty(output_size_per_partition,
                          input_size_per_partition,
                          dtype=torch.float8_e4m3fn)
    return ModelWeightParameter(
        data=storage,
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
def create_fp8_scale_parameter(
        parameter_type: torch.nn.Parameter, output_partition_sizes: list[int],
        input_size_per_partition: int, block_size: Optional[list[int]],
        weight_loader: Optional[Callable]) -> torch.nn.Parameter:
    """Create the weight-scale parameter for the given quantization strategy.

    Args:
        parameter_type: The scale parameter *class* to instantiate; it is
            compared against ``ChannelQuantScaleParameter``,
            ``BlockQuantScaleParameter`` and ``PerTensorScaleParameter``.
        output_partition_sizes: Output sizes of the sub-matrices held by this
            partition.
        input_size_per_partition: Input size held by this partition (only
            used for block scales).
        block_size: ``[block_n, block_k]``; must not be ``None`` for block
            scales.
        weight_loader: Loader callback forwarded to the parameter.

    Returns:
        A float32 scale parameter with every element set to
        ``torch.finfo(torch.float32).min``. Its shape is
        ``(sum(output_partition_sizes), 1)`` for channel scales,
        ``(ceil(out / block_n), ceil(in / block_k))`` for block scales, and
        ``(len(output_partition_sizes),)`` for per-tensor scales.

    Raises:
        ValueError: If ``parameter_type`` is none of the three known classes.
    """
    if parameter_type == ChannelQuantScaleParameter:
        # One scale per output channel.
        channel_shape = (sum(output_partition_sizes), 1)
        scale = parameter_type(
            data=torch.empty(channel_shape, dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
    elif parameter_type == BlockQuantScaleParameter:
        assert block_size is not None
        block_n, block_k = block_size[0], block_size[1]
        output_size_per_partition = sum(output_partition_sizes)
        # Ceil-divide each dimension by its block extent.
        n_row_blocks = (output_size_per_partition + block_n - 1) // block_n
        n_col_blocks = (input_size_per_partition + block_k - 1) // block_k
        scale = parameter_type(
            data=torch.empty(n_row_blocks, n_col_blocks, dtype=torch.float32),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
    elif parameter_type == PerTensorScaleParameter:
        # One scale per logical sub-matrix.
        n_shards = len(output_partition_sizes)
        scale = parameter_type(
            data=torch.empty(n_shards, dtype=torch.float32),
            weight_loader=weight_loader,
        )
    else:
        raise ValueError(f"Unknown parameter type: {parameter_type}")

    # Fill with the most negative float32 value as a placeholder.
    scale[:] = torch.finfo(torch.float32).min
    return scale
def create_fp8_input_scale(
        output_partition_sizes: list[int],
        weight_loader: Optional[Callable]) -> torch.nn.Parameter:
    """Create the input-scale parameter for static activation quantization.

    Args:
        output_partition_sizes: Output sizes of the sub-matrices held by this
            partition; one scale slot is allocated per entry.
        weight_loader: Loader callback forwarded to the parameter.

    Returns:
        A float32 ``PerTensorScaleParameter`` of length
        ``len(output_partition_sizes)`` with every element set to
        ``torch.finfo(torch.float32).min``.
    """
    from vllm.model_executor.parameter import PerTensorScaleParameter

    slots = torch.empty(len(output_partition_sizes), dtype=torch.float32)
    scale = PerTensorScaleParameter(data=slots, weight_loader=weight_loader)
    scale[:] = torch.finfo(torch.float32).min
    return scale
def process_fp8_weight_tensor_strategy(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    logical_widths: list[int],
    input_scale: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Post-process loaded weights for the tensor-wise strategy.

    Steps, in order:

    1. On FP8 FNUZ platforms, pass weight, weight scale and input scale
       through ``normalize_e4m3fn_to_e4m3fnuz``.
    2. Requantize with ``requantize_with_max_scale`` using
       ``logical_widths``.
    3. Apply the optional ROCm padding from ``_maybe_pad_fp8_weight``.

    Args:
        weight: Loaded FP8 weight.
        weight_scale: Loaded weight scale(s).
        logical_widths: Forwarded to ``requantize_with_max_scale``.
        input_scale: Optional static input scale.

    Returns:
        ``(weight, weight_scale, input_scale)`` after processing.
    """
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)

    if current_platform.is_fp8_fnuz():
        normalized = normalize_e4m3fn_to_e4m3fnuz(weight=weight,
                                                  weight_scale=weight_scale,
                                                  input_scale=input_scale)
        weight, weight_scale, input_scale = normalized

    # Note the return order of the helper: scale first, weight second.
    weight_scale, weight = requantize_with_max_scale(
        weight=weight,
        weight_scale=weight_scale,
        logical_widths=logical_widths,
    )

    weight = _maybe_pad_fp8_weight(weight)
    return weight, weight_scale, input_scale
def process_fp8_weight_channel_strategy(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Post-process loaded weights for the channel-wise strategy.

    On FP8 FNUZ platforms the three tensors are passed through
    ``normalize_e4m3fn_to_e4m3fnuz``; on every other platform they are
    returned untouched.

    Args:
        weight: Loaded FP8 weight.
        weight_scale: Loaded per-channel weight scale.
        input_scale: Optional static input scale.

    Returns:
        ``(weight, weight_scale, input_scale)`` after processing.
    """
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        normalize_e4m3fn_to_e4m3fnuz)

    if not current_platform.is_fp8_fnuz():
        return weight, weight_scale, input_scale

    weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
        weight=weight, weight_scale=weight_scale, input_scale=input_scale)
    return weight, weight_scale, input_scale
def process_fp8_weight_block_strategy(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Post-process loaded weights for the block-wise strategy.

    On FP8 FNUZ platforms the weight and its scale are passed through
    ``normalize_e4m3fn_to_e4m3fnuz`` (the third returned value is ignored).
    The weight then goes through the optional ROCm padding in
    ``_maybe_pad_fp8_weight``.

    Args:
        weight: Loaded FP8 weight.
        weight_scale: Loaded block-wise weight scale.

    Returns:
        ``(weight, weight_scale)`` after processing.
    """
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        normalize_e4m3fn_to_e4m3fnuz)

    if current_platform.is_fp8_fnuz():
        weight, weight_scale, _unused_input_scale = (
            normalize_e4m3fn_to_e4m3fnuz(weight=weight,
                                         weight_scale=weight_scale))

    weight = _maybe_pad_fp8_weight(weight)
    return weight, weight_scale
def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
                                        cutlass_block_fp8_supported: bool) -> None:
    """Apply backend-specific fix-ups to a block-quantized FP8 layer.

    Exactly one of the following happens (or nothing at all):

    * If ``is_deep_gemm_e8m0_used()`` is true, ``layer.weight`` and
      ``layer.weight_scale`` are requantized in place through
      ``requant_weight_ue8m0_inplace`` using the layer's block size.
    * Otherwise, if the device capability is 90, ``cutlass_block_fp8_supported``
      is true and ``should_use_deepgemm_for_fp8_linear(torch.bfloat16,
      layer.weight)`` is false, ``layer.weight_scale`` is replaced by a new
      parameter holding its contiguous transpose.

    Args:
        layer: Module whose ``weight``, ``weight_scale`` and
            ``weight_block_size`` attributes are read; ``weight_block_size``
            must not be ``None``.
        cutlass_block_fp8_supported: Whether the CUTLASS block FP8 path is
            available.
    """
    assert layer.weight_block_size is not None

    from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
                                      should_use_deepgemm_for_fp8_linear)

    # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
    # requantize the weight and input to the specific scale
    # at the same time.
    if is_deep_gemm_e8m0_used():
        block_sz = tuple(layer.weight_block_size)
        requant_weight_ue8m0_inplace(layer.weight.data,
                                     layer.weight_scale.data, block_sz)
    # SM90 Block FP8 CUTLASS requires row-major weight scales
    elif (current_platform.is_device_capability(90)
          and cutlass_block_fp8_supported
          and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
                                                     layer.weight)):
        # The new Parameter holds a transposed, contiguous copy of the scales.
        layer.weight_scale = torch.nn.Parameter(
            layer.weight_scale.data.T.contiguous(), requires_grad=False)
def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
                           bias: Optional[torch.Tensor],
                           cutlass_block_fp8_supported: bool,
                           use_aiter_and_is_supported: bool) -> torch.Tensor:
    """Run the block-wise FP8 linear op for ``layer``.

    Dispatches to the registered custom op
    ``torch.ops.vllm.apply_w8a8_block_fp8_linear``.

    Args:
        layer: Module providing ``weight``, ``weight_block_size``,
            ``weight_scale`` and ``input_scale``; ``weight_block_size`` must
            not be ``None``.
        input: Activation tensor.
        bias: Optional bias forwarded to the op.
        cutlass_block_fp8_supported: Forwarded to the op unchanged.
        use_aiter_and_is_supported: Forwarded to the op unchanged.

    Returns:
        The tensor produced by the custom op.
    """
    assert layer.weight_block_size is not None

    block_linear_op = torch.ops.vllm.apply_w8a8_block_fp8_linear
    op_kwargs = {
        "input": input,
        "weight": layer.weight,
        "block_size": layer.weight_block_size,
        "weight_scale": layer.weight_scale,
        "input_scale": layer.input_scale,
        "bias": bias,
        "cutlass_block_fp8_supported": cutlass_block_fp8_supported,
        "use_aiter_and_is_supported": use_aiter_and_is_supported,
    }
    return block_linear_op(**op_kwargs)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment