Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fbd6523a
Unverified
Commit
fbd6523a
authored
Sep 18, 2025
by
Michael Goin
Committed by
GitHub
Sep 18, 2025
Browse files
Refactor dense FP8 tensor/channel/block utils and add CT FP8 block (#21404)
parent
470484a4
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
442 additions
and
318 deletions
+442
-318
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+7
-7
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+35
-33
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+96
-95
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+84
-183
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+220
-0
No files found.
vllm/model_executor/layers/linear.py
View file @
fbd6523a
...
...
@@ -805,12 +805,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
if
isinstance
(
param
,
BlockQuantScaleParameter
):
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8LinearMethod
,
Fp8MoEMethod
)
assert
self
.
quant_method
is
not
None
assert
isinstance
(
self
.
quant
_
method
,
(
Fp8LinearMethod
,
Fp8MoEMethod
)
)
weight_block_size
=
self
.
quant_method
.
quant_config
.
weight_block_size
# Assume the weight block size has been set by
quant
method
assert
hasattr
(
self
,
"weight_block_size"
)
weight_block_size
=
self
.
weight_block_size
assert
weight_block_size
is
not
None
block_n
,
_
=
weight_block_size
[
0
],
weight_block_size
[
1
]
shard_offset
=
(
...
...
@@ -989,8 +987,10 @@ class QKVParallelLinear(ColumnParallelLinear):
# Note(simon): This is needed for Qwen3's fp8 quantization.
if
isinstance
(
param
,
BlockQuantScaleParameter
):
assert
self
.
quant_method
is
not
None
assert
hasattr
(
self
.
quant_method
,
"quant_config"
)
weight_block_size
=
self
.
quant_method
.
quant_config
.
weight_block_size
# Assume the weight block size has been set by quant method
assert
hasattr
(
self
,
"weight_block_size"
)
weight_block_size
=
self
.
weight_block_size
assert
weight_block_size
is
not
None
block_n
,
_
=
weight_block_size
[
0
],
weight_block_size
[
1
]
shard_offset
=
(
shard_offset
+
block_n
-
1
)
//
block_n
shard_size
=
(
shard_size
+
block_n
-
1
)
//
block_n
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
fbd6523a
...
...
@@ -12,7 +12,6 @@ from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy
,
QuantizationType
)
from
compressed_tensors.transform
import
TransformConfig
from
pydantic
import
BaseModel
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
...
...
@@ -268,7 +267,8 @@ class CompressedTensorsConfig(QuantizationConfig):
else
:
return
False
def
_is_fp4a4_nvfp4
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
):
def
_is_fp4a4_nvfp4
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
):
if
weight_quant
is
None
or
input_quant
is
None
:
return
False
...
...
@@ -288,8 +288,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return
(
is_tensor_group_quant
and
is_float_type
and
is_4_bits
and
is_group_size_16
and
is_symmetric
)
def
_is_fp4a16_nvfp4
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
):
def
_is_fp4a16_nvfp4
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
):
is_weight_only
=
weight_quant
is
not
None
and
input_quant
is
None
is_tensor_group_quant
=
(
...
...
@@ -303,8 +303,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return
(
is_weight_only
and
is_tensor_group_quant
and
is_float_type
and
is_4_bits
and
is_group_size_16
and
is_symmetric
)
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
is_8_bits
=
weight_quant
.
num_bits
==
input_quant
.
num_bits
==
8
weight_strategy
=
(
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
...
...
@@ -317,8 +317,8 @@ class CompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return
is_8_bits
and
is_tensor
and
weight_quant
.
symmetric
and
is_static
def
_is_dynamic_token_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_dynamic_token_w8a8
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
is_8_bits
=
weight_quant
.
num_bits
==
input_quant
.
num_bits
==
8
weight_strategy
=
(
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
...
...
@@ -331,8 +331,8 @@ class CompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return
is_8_bits
and
is_token
and
weight_quant
.
symmetric
and
is_dynamic
def
_is_dynamic_token_w4a8_int
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_dynamic_token_w4a8_int
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
is_weight_4_bits
=
weight_quant
.
num_bits
==
4
is_activation_8_bits
=
input_quant
.
num_bits
==
8
weight_strategy
=
(
...
...
@@ -347,8 +347,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return
(
is_weight_4_bits
and
is_activation_8_bits
and
is_token
and
weight_quant
.
symmetric
and
is_dynamic
)
def
_is_fp8_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_fp8_w8a8
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
# Confirm weights and activations quantized.
if
weight_quant
is
None
or
input_quant
is
None
:
return
False
...
...
@@ -358,11 +358,12 @@ class CompressedTensorsConfig(QuantizationConfig):
and
input_quant
.
type
==
QuantizationType
.
FLOAT
)
is_symmetric_weight
=
weight_quant
.
symmetric
is_static_weight
=
not
weight_quant
.
dynamic
is_per_tensor_or_channel_weight
=
(
weight_quant
.
strategy
in
[
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
is_tensor_or_channel_or_block_weight
=
(
weight_quant
.
strategy
in
[
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
,
QuantizationStrategy
.
BLOCK
])
if
not
(
is_floating_point
and
is_symmetric_weight
and
is_static_weight
and
is_
per_
tensor_or_channel_weight
):
and
is_tensor_or_channel_
or_block_
weight
):
return
False
# Dynamic quantization is always supported if weights supported.
...
...
@@ -375,8 +376,8 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
return
is_symmetric_activation
and
is_per_tensor_activation
def
_is_fp8_w4a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_fp8_w4a8
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
if
not
weight_quant
or
not
input_quant
:
return
False
is_weight_4_bits
=
weight_quant
.
num_bits
==
4
...
...
@@ -392,24 +393,24 @@ class CompressedTensorsConfig(QuantizationConfig):
return
(
is_weight_4_bits
and
is_activation_8_bits
and
is_token
and
is_symmetric
and
is_dynamic
)
def
_is_fp8_w4a8_sm90
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_fp8_w4a8_sm90
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
return
(
self
.
_check_scheme_supported
(
90
,
error
=
False
,
match_exact
=
True
)
and
self
.
_is_fp8_w4a8
(
weight_quant
,
input_quant
))
def
_is_fp8_w8a8_sm90
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_fp8_w8a8_sm90
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
return
(
self
.
_check_scheme_supported
(
90
,
error
=
False
,
match_exact
=
True
)
and
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
))
def
_is_fp8_w8a8_sm100
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_fp8_w8a8_sm100
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
return
(
self
.
_check_scheme_supported
(
100
,
error
=
False
,
match_exact
=
True
)
and
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
))
def
_is_fp8_w8a16
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_fp8_w8a16
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
# Confirm weights quantized.
if
weight_quant
is
None
:
return
False
...
...
@@ -421,18 +422,19 @@ class CompressedTensorsConfig(QuantizationConfig):
# Confirm weight scheme is supported.
is_symmetric_weight
=
weight_quant
.
symmetric
is_static_weight
=
not
weight_quant
.
dynamic
is_per_tensor_or_channel_weight
=
(
weight_quant
.
strategy
in
[
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
is_tensor_or_channel_or_block_weight
=
(
weight_quant
.
strategy
in
[
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
,
QuantizationStrategy
.
BLOCK
])
if
not
(
is_symmetric_weight
and
is_static_weight
# noqa: SIM103
and
is_
per_
tensor_or_channel_weight
):
and
is_tensor_or_channel_
or_block_
weight
):
return
False
# All conditions satisfied.
return
True
def
_is_wNa16_group_channel
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
def
_is_wNa16_group_channel
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
)
->
bool
:
input_quant_none
=
input_quant
is
None
is_channel_group
=
(
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
...
...
@@ -443,8 +445,8 @@ class CompressedTensorsConfig(QuantizationConfig):
def
_get_scheme_from_parts
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
,
format
:
Optional
[
str
]
=
None
)
->
"CompressedTensorsScheme"
:
# use the per-layer format if defined, otherwise, use global format
...
...
@@ -496,7 +498,7 @@ class CompressedTensorsConfig(QuantizationConfig):
CompressedTensorsW8A8Fp8
.
get_min_capability
(),
error
=
False
)
if
is_fp8_w8a8_supported
:
return
CompressedTensorsW8A8Fp8
(
strategy
=
weight_quant
.
strategy
,
weight_quant
=
weight_quant
,
is_static_input_scheme
=
(
input_quant
and
not
input_quant
.
dynamic
))
else
:
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
fbd6523a
...
...
@@ -4,28 +4,41 @@
from
typing
import
Callable
,
Optional
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
)
from
torch.nn
import
Parameter
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_fp8_block_linear
,
check_aiter_fp8_linear_support
,
create_fp8_input_scale
,
create_fp8_scale_parameter
,
create_fp8_weight_parameter
,
maybe_post_process_fp8_weight_block
,
process_fp8_weight_block_strategy
,
process_fp8_weight_channel_strategy
,
process_fp8_weight_tensor_strategy
,
validate_fp8_block_shape
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
,
Fp8LinearOp
,
cutlass_block_fp8_supported
,
maybe_create_device_identity
)
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
ChannelQuantScaleParameter
,
PerTensorScaleParameter
)
from
vllm.platforms
import
current_platform
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
strategy_to_parameter_type
=
{
QuantizationStrategy
.
BLOCK
:
BlockQuantScaleParameter
,
QuantizationStrategy
.
CHANNEL
:
ChannelQuantScaleParameter
,
QuantizationStrategy
.
TENSOR
:
PerTensorScaleParameter
,
}
class
CompressedTensorsW8A8Fp8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
):
self
.
strategy
=
strategy
def
__init__
(
self
,
weight_quant
:
QuantizationArgs
,
is_static_input_scheme
:
bool
):
self
.
weight_quant
=
weight_quant
self
.
strategy
=
weight_quant
.
strategy
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
act_q_group_shape
=
GroupShape
.
PER_TENSOR
\
...
...
@@ -34,61 +47,84 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
act_quant_static
=
self
.
is_static_input_scheme
,
act_quant_group_shape
=
self
.
act_q_group_shape
)
self
.
weight_block_size
=
self
.
weight_quant
.
block_structure
self
.
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
()
self
.
use_aiter_and_is_supported
=
check_aiter_fp8_linear_support
()
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# lovelace and up
return
89
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor
if
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
max_w_scale
,
weight
=
requantize_with_max_scale
(
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
maybe_create_device_identity
()
if
current_platform
.
is_fp8_fnuz
():
input_scale
=
getattr
(
layer
,
'input_scale'
,
None
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
weight_block_size
=
None
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
max_w_scale
,
input_scale
=
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
if
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
assert
self
.
weight_block_size
is
not
None
layer
.
weight_block_size
=
self
.
weight_block_size
# Validate block quantization shapes
validate_fp8_block_shape
(
layer
,
input_size
,
output_size
,
input_size_per_partition
,
output_partition_sizes
,
self
.
weight_block_size
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
# WEIGHT
weight
=
create_fp8_weight_parameter
(
output_size_per_partition
,
input_size_per_partition
,
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
# If channelwise, scales are already lined up, so just transpose.
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
=
layer
.
weight
# WEIGHT SCALE
weight_scale
=
create_fp8_scale_parameter
(
strategy_to_parameter_type
[
self
.
strategy
],
output_partition_sizes
,
input_size_per_partition
,
layer
.
weight_block_size
,
weight_loader
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
if
current_platform
.
is_fp8_fnuz
():
input_scale
=
getattr
(
layer
,
'input_scale'
,
None
)
# INPUT SCALE
if
self
.
is_static_input_scheme
:
input_scale
=
create_fp8_input_scale
(
output_partition_sizes
,
weight_loader
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
else
:
weight_scale
=
layer
.
weight_scale
.
data
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
if
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
weight
,
weight_scale
,
input_scale
=
(
process_fp8_weight_tensor_strategy
(
layer
.
weight
,
layer
.
weight_scale
,
layer
.
logical_widths
,
getattr
(
layer
,
'input_scale'
,
None
)))
weight
=
weight
.
t
()
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
# required by torch.compile to be torch.nn.Parameter
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
,
weight_scale
,
input_scale
=
(
process_fp8_weight_channel_strategy
(
layer
.
weight
,
layer
.
weight_scale
,
getattr
(
layer
,
'input_scale'
,
None
)))
weight
=
weight
.
t
()
elif
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
assert
self
.
is_static_input_scheme
is
False
weight
,
weight_scale
=
process_fp8_weight_block_strategy
(
layer
.
weight
,
layer
.
weight_scale
)
input_scale
=
None
else
:
raise
ValueError
(
f
"Unknown quantization strategy
{
self
.
strategy
}
"
)
# required by torch.compile to be torch.nn.Parameter
layer
.
weight
=
Parameter
(
weight
.
data
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
.
data
,
requires_grad
=
False
)
# INPUT SCALE
if
self
.
is_static_input_scheme
and
hasattr
(
layer
,
'input_scale'
):
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
...
...
@@ -96,58 +132,23 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
else
:
layer
.
input_scale
=
None
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
list
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
maybe_create_device_identity
()
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
logical_widths
=
output_partition_sizes
# WEIGHT
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
# WEIGHT SCALE
# TODO: update create_xxx_parameter functions to return
# the newly added parameters
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
((
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
else
:
assert
self
.
strategy
==
QuantizationStrategy
.
TENSOR
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
# min requirement for fp8 kernels
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE
if
self
.
is_static_input_scheme
:
input_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
input_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
if
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
maybe_post_process_fp8_weight_block
(
layer
,
self
.
cutlass_block_fp8_supported
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
layer
.
weight_block_size
is
not
None
:
return
apply_fp8_block_linear
(
layer
,
input
=
x
,
bias
=
bias
,
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
)
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
fbd6523a
...
...
@@ -4,7 +4,6 @@
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch.nn.functional
as
F
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
...
...
@@ -32,8 +31,12 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
register_moe_scaling_factors
,
rotate_flashinfer_fp8_moe_weights
,
select_cutlass_fp8_gemm_impl
,
swap_w13_to_w31
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
get_col_major_tma_aligned_tensor
,
requant_weight_ue8m0_inplace
,
should_use_deepgemm_for_fp8_linear
)
apply_fp8_block_linear
,
check_aiter_fp8_linear_support
,
create_fp8_input_scale
,
create_fp8_scale_parameter
,
create_fp8_weight_parameter
,
get_col_major_tma_aligned_tensor
,
maybe_post_process_fp8_weight_block
,
process_fp8_weight_block_strategy
,
process_fp8_weight_tensor_strategy
,
requant_weight_ue8m0_inplace
,
validate_fp8_block_shape
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
,
prepare_moe_fp8_layer_for_marlin
)
...
...
@@ -42,8 +45,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
all_close_1d
,
cutlass_block_fp8_supported
,
cutlass_fp8_supported
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
,
requantize_with_max_scale
)
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
...
...
@@ -233,14 +235,10 @@ class Fp8LinearMethod(LinearMethodBase):
if
current_platform
.
is_rocm
():
self
.
use_marlin
=
False
# AITER is only supported on ROCm and only for FP8_FNUZ
# and at the moment are MI300 series
self
.
use_aiter_and_is_supported
=
(
current_platform
.
is_rocm
()
and
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_LINEAR
and
current_platform
.
is_fp8_fnuz
())
self
.
use_aiter_and_is_supported
=
check_aiter_fp8_linear_support
()
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
self
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
self
.
block_quant
=
self
.
weight_block_size
is
not
None
self
.
act_q_static
=
self
.
quant_config
.
activation_scheme
==
"static"
# Use per-token quantization for better perf if dynamic and cutlass
if
not
self
.
act_q_static
and
cutlass_fp8_supported
():
...
...
@@ -273,48 +271,24 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
weight_block_size
=
None
if
self
.
block_quant
:
tp_size
=
getattr
(
layer
,
"tp_size"
,
get_tensor_model_parallel_world_size
())
assert
self
.
quant_config
.
weight_block_size
is
not
None
layer
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
block_n
,
block_k
=
(
self
.
quant_config
.
weight_block_size
[
0
],
self
.
quant_config
.
weight_block_size
[
1
],
)
# Required by row parallel
if
(
tp_size
>
1
and
input_size
//
input_size_per_partition
==
tp_size
and
input_size_per_partition
%
block_k
!=
0
):
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible by "
f
"weight quantization block_k =
{
block_k
}
."
)
# Required by column parallel or enabling merged weights
is_tp_split
=
(
tp_size
>
1
and
output_size
//
output_size_per_partition
==
tp_size
)
is_merged_gemm
=
len
(
output_partition_sizes
)
>
1
if
is_tp_split
or
is_merged_gemm
:
sizes_to_check
=
output_partition_sizes
if
not
is_tp_split
and
is_merged_gemm
:
# In case of merged matrices, we allow the last
# matrix to not be a multiple of block size
sizes_to_check
=
output_partition_sizes
[:
-
1
]
for
output_partition_size
in
sizes_to_check
:
if
output_partition_size
%
block_n
!=
0
:
raise
ValueError
(
f
"Weight output_partition_size = "
f
"
{
output_partition_size
}
is not divisible by "
f
"weight quantization block_n =
{
block_n
}
."
)
assert
self
.
weight_block_size
is
not
None
layer
.
weight_block_size
=
self
.
weight_block_size
validate_fp8_block_shape
(
layer
,
input_size
,
output_size
,
input_size_per_partition
,
output_partition_sizes
,
self
.
weight_block_size
)
# WEIGHT
weight_dtype
=
(
torch
.
float8_e4m3fn
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
params_dtype
)
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
weight
=
create_fp8_weight_parameter
(
output_size_per_partition
,
input_size_per_partition
,
weight_loader
)
else
:
# For non-serialized checkpoints, use original dtype
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
weight
_dtype
),
dtype
=
params
_dtype
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
...
...
@@ -325,154 +299,87 @@ class Fp8LinearMethod(LinearMethodBase):
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
if
not
self
.
block_quant
:
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
scale
=
create_fp8_scale_parameter
(
PerTensorScaleParameter
,
output_partition_sizes
,
input_size_per_partition
,
None
,
weight_loader
)
set_weight_attrs
(
scale
,
{
"scale_type"
:
"weight_scale"
})
layer
.
register_parameter
(
"weight_scale"
,
scale
)
else
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
scale
=
BlockQuantScaleParameter
(
data
=
torch
.
empty
(
(
output_size_per_partition
+
block_n
-
1
)
//
block_n
,
(
input_size_per_partition
+
block_k
-
1
)
//
block_k
,
dtype
=
torch
.
float32
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
assert
not
self
.
act_q_static
assert
self
.
weight_block_size
is
not
None
scale
=
create_fp8_scale_parameter
(
BlockQuantScaleParameter
,
output_partition_sizes
,
input_size_per_partition
,
self
.
weight_block_size
,
weight_loader
)
set_weight_attrs
(
scale
,
{
"scale_type"
:
"weight_scale"
})
# The weight_scale_inv name is intentional for deepseekv3
layer
.
register_parameter
(
"weight_scale_inv"
,
scale
)
# INPUT ACTIVATION SCALE
if
self
.
quant_config
.
activation_scheme
==
"static"
:
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
if
self
.
act_q_static
:
scale
=
create_fp8_input_scale
(
output_partition_sizes
,
weight_loader
)
set_weight_attrs
(
scale
,
{
"scale_type"
:
"input_scale"
})
layer
.
register_parameter
(
"input_scale"
,
scale
)
else
:
layer
.
register_parameter
(
"input_scale"
,
None
)
def
_maybe_pad_weight
(
self
,
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Pad the weight tensor. This is an optimization on ROCm platform, which
# can benefit from tensors located far enough from one another in memory
if
(
envs
.
VLLM_ROCM_FP8_PADDING
and
current_platform
.
is_rocm
()
and
weight
.
stride
(
-
1
)
==
1
and
(
weight
.
stride
(
-
2
)
*
weight
.
element_size
())
%
512
==
0
):
num_pad
=
256
//
weight
.
element_size
()
weight
=
F
.
pad
(
weight
,
(
0
,
num_pad
),
"constant"
,
0
)[...,
:
-
num_pad
]
torch
.
cuda
.
empty_cache
()
return
weight
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
size_k_first
=
True
input_scale
=
None
# TODO(rob): refactor block quant into separate class.
if
self
.
block_quant
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
assert
not
self
.
act_q_static
size_k_first
=
False
if
current_platform
.
is_fp8_fnuz
():
weight
,
weight_scale_inv
,
_
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale_inv
)
else
:
weight
=
layer
.
weight
.
data
weight_scale_inv
=
layer
.
weight_scale_inv
.
data
weight
=
self
.
_maybe_pad_weight
(
weight
)
# Torch.compile cannot use Parameter subclasses.
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight_scale_inv
=
Parameter
(
weight_scale_inv
,
requires_grad
=
False
)
weight
,
weight_scale
=
process_fp8_weight_block_strategy
(
layer
.
weight
,
layer
.
weight_scale_inv
)
# Delete the weight_scale_inv parameter to avoid confusion
# with the weight_scale parameter
del
layer
.
weight_scale_inv
# If checkpoint not serialized fp8, quantize the weights.
elif
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
weight
=
qweight
.
t
()
# Update the layer with the new values.
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
# layer.input_scale is None indicates dynamic quant and scale is
# computed from input.
layer
.
input_scale
=
None
# If checkpoint is fp8, handle that there are N scales for N
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
# shards in a fused module
else
:
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
if
self
.
quant_config
.
activation_scheme
==
"static"
:
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
if
not
self
.
use_marlin
:
# Dequant -> Quant with max scale so we can run per tensor.
if
current_platform
.
is_fp8_fnuz
():
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
weight_scale
,
input_scale
=
layer
.
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
weight
,
weight_scale
,
input_scale
=
(
process_fp8_weight_tensor_strategy
(
weight
,
weight_scale
,
layer
.
logical_widths
,
getattr
(
layer
,
'input_scale'
,
None
)))
if
self
.
act_q_static
:
assert
input_scale
is
not
None
input_scale
=
input_scale
.
max
()
weight
=
weight
.
t
()
weight_scale
,
weight
=
requantize_with_max_scale
(
weight
=
weight
,
weight_scale
=
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
)
weight
=
self
.
_maybe_pad_weight
(
weight
)
# Update layer with new values.
layer
.
weight
=
Parameter
(
weight
.
t
()
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
if
self
.
quant_config
.
activation_schem
e
=
=
"static"
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
()
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
data
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
)
layer
.
input_scal
e
=
Parameter
(
input_scale
,
requires_grad
=
False
)
if
input_scale
is
not
None
else
None
if
self
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
layer
,
size_k_first
)
# Activations not quantized for marlin.
del
layer
.
input_scale
return
# On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
# requantize the weight and input to the specific scale
# at the same time.
if
is_deep_gemm_e8m0_used
()
and
self
.
block_quant
:
assert
layer
.
weight_block_size
is
not
None
block_sz
=
tuple
(
layer
.
weight_block_size
)
requant_weight_ue8m0_inplace
(
layer
.
weight
.
data
,
layer
.
weight_scale_inv
.
data
if
hasattr
(
layer
,
"weight_scale_inv"
)
else
layer
.
weight_scale
.
data
,
block_sz
,
)
# SM90 Block FP8 CUTLASS requires row-major weight scales
if
(
self
.
block_quant
and
current_platform
.
is_device_capability
(
90
)
and
self
.
cutlass_block_fp8_supported
and
not
should_use_deepgemm_for_fp8_linear
(
torch
.
bfloat16
,
layer
.
weight
)):
layer
.
weight_scale_inv
=
Parameter
(
layer
.
weight_scale_inv
.
data
.
T
.
contiguous
(),
requires_grad
=
False
)
if
self
.
block_quant
:
maybe_post_process_fp8_weight_block
(
layer
,
self
.
cutlass_block_fp8_supported
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -490,18 +397,12 @@ class Fp8LinearMethod(LinearMethodBase):
bias
=
bias
)
if
self
.
block_quant
:
assert
self
.
quant_config
.
weight_block_size
is
not
None
return
torch
.
ops
.
vllm
.
apply_w8a8_block_fp8_linear
(
return
apply_fp8_block_linear
(
layer
,
input
=
x
,
weight
=
layer
.
weight
,
block_size
=
self
.
quant_config
.
weight_block_size
,
weight_scale
=
layer
.
weight_scale_inv
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
,
)
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
)
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
...
...
@@ -528,7 +429,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
super
().
__init__
(
layer
.
moe_config
)
self
.
layer
=
layer
self
.
quant_config
=
quant_config
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
self
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
self
.
block_quant
=
self
.
weight_block_size
is
not
None
self
.
flashinfer_moe_backend
:
Optional
[
FlashinferMoeBackend
]
=
None
self
.
fused_experts
:
Optional
[
...
...
@@ -590,12 +492,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
params_dtype
=
torch
.
float8_e4m3fn
if
self
.
block_quant
:
assert
self
.
quant_config
.
weight_block_size
is
not
None
layer
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
assert
self
.
weight_block_size
is
not
None
layer
.
weight_block_size
=
self
.
weight_block_size
tp_size
=
get_tensor_model_parallel_world_size
()
block_n
,
block_k
=
(
self
.
quant_config
.
weight_block_size
[
0
],
self
.
quant_config
.
weight_block_size
[
1
],
self
.
weight_block_size
[
0
],
self
.
weight_block_size
[
1
],
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
...
...
@@ -952,7 +854,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"BatchedTritonOrDeepGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
max_num_tokens_per_rank
,
self
.
quant_config
.
weight_block_size
,
False
)
self
.
weight_block_size
,
False
)
return
BatchedTritonOrDeepGemmExperts
(
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
...
...
@@ -969,8 +871,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
else
:
logger
.
debug
(
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
self
.
quant_config
.
weight_block_size
,
False
)
self
.
__class__
.
__name__
,
self
.
weight_block_size
,
False
)
return
TritonOrDeepGemmExperts
(
quant_config
=
self
.
moe_quant_config
,
allow_deep_gemm
=
self
.
allow_deep_gemm
,
...
...
@@ -988,7 +889,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
self
.
block_quant
else
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
block_shape
=
self
.
quant_config
.
weight_block_size
,
block_shape
=
self
.
weight_block_size
,
)
def
apply
(
...
...
@@ -1046,7 +947,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
intermediate_size
=
layer
.
intermediate_size_per_partition
,
expert_offset
=
layer
.
ep_rank
*
layer
.
local_num_experts
,
local_num_experts
=
layer
.
local_num_experts
,
block_shape
=
self
.
quant_config
.
weight_block_size
,
block_shape
=
self
.
weight_block_size
,
routed_scaling
=
routed_scaling_factor
,
)
else
:
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
fbd6523a
...
...
@@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_BLOCK_FP8_SUPPORTED
)
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
ChannelQuantScaleParameter
,
PerTensorScaleParameter
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
,
direct_register_custom_op
...
...
@@ -794,3 +797,220 @@ def requant_weight_ue8m0_inplace(
# Write back the results in-place.
w_q
.
copy_
(
w_requant
)
s_old
.
copy_
(
s_requant
)
def
check_aiter_fp8_linear_support
()
->
bool
:
"""AITER is only supported on ROCm and only for FP8_FNUZ
and at the moment are MI300 series"""
return
(
current_platform
.
is_rocm
()
and
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_LINEAR
and
current_platform
.
is_fp8_fnuz
())
def
_maybe_pad_fp8_weight
(
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""
if
(
envs
.
VLLM_ROCM_FP8_PADDING
and
current_platform
.
is_rocm
()
and
weight
.
stride
(
-
1
)
==
1
and
(
weight
.
stride
(
-
2
)
*
weight
.
element_size
())
%
512
==
0
):
num_pad
=
256
//
weight
.
element_size
()
import
torch.nn.functional
as
F
weight
=
F
.
pad
(
weight
,
(
0
,
num_pad
),
"constant"
,
0
)[...,
:
-
num_pad
]
torch
.
cuda
.
empty_cache
()
return
weight
def
validate_fp8_block_shape
(
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_size
:
int
,
input_size_per_partition
:
int
,
output_partition_sizes
:
list
[
int
],
block_size
:
list
[
int
])
->
None
:
"""Validate block quantization shapes for tensor parallelism."""
from
vllm.distributed
import
get_tensor_model_parallel_world_size
tp_size
=
getattr
(
layer
,
"tp_size"
,
get_tensor_model_parallel_world_size
())
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
# Required by row parallel
if
(
tp_size
>
1
and
input_size
//
input_size_per_partition
==
tp_size
and
input_size_per_partition
%
block_k
!=
0
):
raise
ValueError
(
f
"Weight input_size_per_partition =
{
input_size_per_partition
}
"
f
"is not divisible by weight quantization block_k =
{
block_k
}
."
)
# Required by column parallel or enabling merged weights
is_tp_split
=
(
tp_size
>
1
and
output_size
//
sum
(
output_partition_sizes
)
==
tp_size
)
is_merged_gemm
=
len
(
output_partition_sizes
)
>
1
if
is_tp_split
or
is_merged_gemm
:
sizes_to_check
=
output_partition_sizes
if
not
is_tp_split
and
is_merged_gemm
:
# In case of merged matrices, we allow the last
# matrix to not be a multiple of block size
sizes_to_check
=
output_partition_sizes
[:
-
1
]
for
output_partition_size
in
sizes_to_check
:
if
output_partition_size
%
block_n
!=
0
:
raise
ValueError
(
f
"Weight output_partition_size = "
f
"
{
output_partition_size
}
is not divisible by "
f
"weight quantization block_n =
{
block_n
}
."
)
def
create_fp8_weight_parameter
(
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
weight_loader
:
Optional
[
Callable
])
->
torch
.
nn
.
Parameter
:
"""Create FP8 weight parameter."""
from
vllm.model_executor.parameter
import
ModelWeightParameter
return
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
def
create_fp8_scale_parameter
(
parameter_type
:
torch
.
nn
.
Parameter
,
output_partition_sizes
:
list
[
int
],
input_size_per_partition
:
int
,
block_size
:
Optional
[
list
[
int
]],
weight_loader
:
Optional
[
Callable
])
->
torch
.
nn
.
Parameter
:
"""Create scale parameter based on quantization strategy."""
if
parameter_type
==
ChannelQuantScaleParameter
:
scale
=
parameter_type
(
data
=
torch
.
empty
(
(
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
elif
parameter_type
==
BlockQuantScaleParameter
:
assert
block_size
is
not
None
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
output_size_per_partition
=
sum
(
output_partition_sizes
)
scale
=
parameter_type
(
data
=
torch
.
empty
(
(
output_size_per_partition
+
block_n
-
1
)
//
block_n
,
(
input_size_per_partition
+
block_k
-
1
)
//
block_k
,
dtype
=
torch
.
float32
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
elif
parameter_type
==
PerTensorScaleParameter
:
scale
=
parameter_type
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
else
:
raise
ValueError
(
f
"Unknown parameter type:
{
parameter_type
}
"
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
return
scale
def
create_fp8_input_scale
(
output_partition_sizes
:
list
[
int
],
weight_loader
:
Optional
[
Callable
])
->
torch
.
nn
.
Parameter
:
"""Create input scale parameter for static activation quantization."""
from
vllm.model_executor.parameter
import
PerTensorScaleParameter
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
return
scale
def
process_fp8_weight_tensor_strategy
(
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
logical_widths
:
list
[
int
],
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Process weights for tensor-wise quantization strategy."""
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
normalize_e4m3fn_to_e4m3fnuz
,
requantize_with_max_scale
)
if
current_platform
.
is_fp8_fnuz
():
weight
,
weight_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
weight_scale
,
input_scale
=
input_scale
)
# Requantize with max scale
weight_scale
,
weight
=
requantize_with_max_scale
(
weight
=
weight
,
weight_scale
=
weight_scale
,
logical_widths
=
logical_widths
,
)
weight
=
_maybe_pad_fp8_weight
(
weight
)
return
weight
,
weight_scale
,
input_scale
def
process_fp8_weight_channel_strategy
(
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Process weights for channel-wise quantization strategy."""
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
normalize_e4m3fn_to_e4m3fnuz
)
if
current_platform
.
is_fp8_fnuz
():
weight
,
weight_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
weight_scale
,
input_scale
=
input_scale
)
return
weight
,
weight_scale
,
input_scale
def
process_fp8_weight_block_strategy
(
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Process weights for block-wise quantization strategy."""
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
normalize_e4m3fn_to_e4m3fnuz
)
if
current_platform
.
is_fp8_fnuz
():
weight
,
weight_scale
,
_
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
weight_scale
)
weight
=
_maybe_pad_fp8_weight
(
weight
)
return
weight
,
weight_scale
def
maybe_post_process_fp8_weight_block
(
layer
:
torch
.
nn
.
Module
,
cutlass_block_fp8_supported
:
bool
):
assert
layer
.
weight_block_size
is
not
None
from
vllm.utils.deep_gemm
import
(
is_deep_gemm_e8m0_used
,
should_use_deepgemm_for_fp8_linear
)
# On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
# requantize the weight and input to the specific scale
# at the same time.
if
is_deep_gemm_e8m0_used
():
block_sz
=
tuple
(
layer
.
weight_block_size
)
requant_weight_ue8m0_inplace
(
layer
.
weight
.
data
,
layer
.
weight_scale
.
data
,
block_sz
)
# SM90 Block FP8 CUTLASS requires row-major weight scales
elif
(
current_platform
.
is_device_capability
(
90
)
and
cutlass_block_fp8_supported
and
not
should_use_deepgemm_for_fp8_linear
(
torch
.
bfloat16
,
layer
.
weight
)):
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
data
.
T
.
contiguous
(),
requires_grad
=
False
)
def
apply_fp8_block_linear
(
layer
:
torch
.
nn
.
Module
,
input
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
cutlass_block_fp8_supported
:
bool
,
use_aiter_and_is_supported
:
bool
)
->
torch
.
Tensor
:
"""Apply block-wise FP8 linear operation."""
assert
layer
.
weight_block_size
is
not
None
return
torch
.
ops
.
vllm
.
apply_w8a8_block_fp8_linear
(
input
=
input
,
weight
=
layer
.
weight
,
block_size
=
layer
.
weight_block_size
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
=
use_aiter_and_is_supported
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment