Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
18fecc35
Unverified
Commit
18fecc35
authored
Jul 17, 2024
by
Robert Shaw
Committed by
GitHub
Jul 18, 2024
Browse files
[ Kernel ] Fp8 Channelwise Weight Support (#6487)
parent
b5af8c22
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
76 additions
and
35 deletions
+76
-35
vllm/config.py
vllm/config.py
+2
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+11
-7
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+53
-27
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+10
-0
No files found.
vllm/config.py
View file @
18fecc35
...
@@ -238,7 +238,8 @@ class ModelConfig:
...
@@ -238,7 +238,8 @@ class ModelConfig:
f
"
{
self
.
quantization
}
quantization is currently not "
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
f
"supported in ROCm."
)
if
(
self
.
quantization
if
(
self
.
quantization
not
in
(
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
)):
not
in
(
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"compressed_tensors"
)):
logger
.
warning
(
logger
.
warning
(
"%s quantization is not fully "
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"optimized yet. The speed can be slower than "
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
18fecc35
...
@@ -13,7 +13,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
...
@@ -13,7 +13,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsWNA16
)
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
,
find_first_name_or_class_match
)
QuantizationType
,
find_first_name_or_class_match
,
is_activation_quantization_format
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -132,10 +133,11 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -132,10 +133,11 @@ class CompressedTensorsConfig(QuantizationConfig):
# Confirm weight scheme is supported.
# Confirm weight scheme is supported.
is_symmetric_weight
=
weight_quant
.
symmetric
is_symmetric_weight
=
weight_quant
.
symmetric
is_static_weight
=
not
weight_quant
.
dynamic
is_static_weight
=
not
weight_quant
.
dynamic
is_per_tensor_weight
=
(
is_per_tensor_or_channel_weight
=
(
weight_quant
.
strategy
in
[
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
])
if
not
(
is_symmetric_weight
and
is_static_weight
if
not
(
is_symmetric_weight
and
is_static_weight
and
is_per_tensor_weight
):
and
is_per_tensor_
or_channel_
weight
):
return
False
return
False
# Dynamic quantization is always supported if weights supported.
# Dynamic quantization is always supported if weights supported.
...
@@ -167,6 +169,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -167,6 +169,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def
_get_schema
(
self
,
weight_quant
:
BaseModel
,
def
_get_schema
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
"CompressedTensorsScheme"
:
input_quant
:
BaseModel
)
->
"CompressedTensorsScheme"
:
# Detect If Mixed Precision
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
self
.
_check_gptq_and_marlin_can_run
()
self
.
_check_gptq_and_marlin_can_run
()
if
(
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
if
(
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
...
@@ -182,11 +185,12 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -182,11 +185,12 @@ class CompressedTensorsConfig(QuantizationConfig):
strategy
=
weight_quant
.
strategy
,
strategy
=
weight_quant
.
strategy
,
group_size
=
weight_quant
.
group_size
)
group_size
=
weight_quant
.
group_size
)
if
(
self
.
quant_format
==
CompressionFormat
.
int_quantized
.
value
or
# Detect If Activation Quantization.
self
.
quant_format
==
Compress
ion
F
ormat
.
float_quantized
.
value
):
if
is_activation_quantizat
ion
_f
ormat
(
self
.
quant_format
):
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Fp8
(
return
CompressedTensorsW8A8Fp8
(
input_dynamic
=
input_quant
.
dynamic
)
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
(
not
input_quant
.
dynamic
))
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8
(
return
CompressedTensorsW8A8Int8
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
18fecc35
from
typing
import
Callable
,
List
,
Optional
from
typing
import
Callable
,
List
,
Optional
import
torch
import
torch
from
torch.nn
import
Parameter
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
apply_fp8_linear
,
create_per_channel_scale_param
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
requantize_with_max_scale
)
requantize_with_max_scale
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -14,39 +18,56 @@ __all__ = ["CompressedTensorsW8A8Fp8"]
...
@@ -14,39 +18,56 @@ __all__ = ["CompressedTensorsW8A8Fp8"]
class
CompressedTensorsW8A8Fp8
(
CompressedTensorsScheme
):
class
CompressedTensorsW8A8Fp8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
input_dynamic
:
bool
):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
):
self
.
input_dynamic
=
input_dynamic
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
# On Lovelace, fail for now if channelwise.
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
# TODO: (@tms) fallback
# scales being passed to the kernel), we requantize with a single scale.
if
(
not
self
.
cutlass_fp8_supported
and
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
):
raise
ValueError
(
"Channelwise fp8 quantization requires vLLM's custom "
"cutlass kernels, which are not supported on your device."
"Consider quantizing with per tensor scales or upgrading "
"to Hopper."
)
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
# Dequant -> Quant with max scale.
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor
if
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
max_w_scale
,
weight
=
requantize_with_max_scale
(
max_w_scale
,
weight
=
requantize_with_max_scale
(
weight
=
layer
.
weight
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
logical_widths
=
layer
.
logical_widths
,
)
)
# Update layer with new values.
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
torch
.
nn
.
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
max_w_scale
,
requires_grad
=
False
)
# If channelwise, scales are already lined up, so just transpose.
if
self
.
input_dynamic
:
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
layer
.
input_scale
=
None
assert
self
.
cutlass_fp8_supported
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
else
:
else
:
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
max
(),
raise
ValueError
(
f
"Unknown quantization strategy
{
self
.
strategy
}
"
)
# INPUT SCALE
if
self
.
is_static_input_scheme
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
requires_grad
=
False
)
else
:
layer
.
input_scale
=
None
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
**
kwargs
):
del
params_dtype
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
logical_widths
=
output_partition_sizes
...
@@ -63,12 +84,17 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -63,12 +84,17 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
})
})
# WEIGHT SCALE
# WEIGHT SCALE
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight_scale
=
create_per_channel_scale_param
(
output_partition_sizes
,
weight_loader
=
weight_loader
)
else
:
assert
self
.
strategy
==
QuantizationStrategy
.
TENSOR
weight_scale
=
create_per_tensor_scale_param
(
weight_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
weight_loader
=
weight_loader
)
output_partition_sizes
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE
# INPUT SCALE
if
not
self
.
i
nput_dynamic
:
if
self
.
i
s_static_input_scheme
:
input_scale
=
create_per_tensor_scale_param
(
input_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
weight_loader
=
weight_loader
)
output_partition_sizes
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
18fecc35
...
@@ -9,6 +9,7 @@ from torch.nn import Module
...
@@ -9,6 +9,7 @@ from torch.nn import Module
class
CompressionFormat
(
Enum
):
class
CompressionFormat
(
Enum
):
dense
=
"dense"
dense
=
"dense"
sparse_bitmask
=
"sparse-bitmask"
sparse_bitmask
=
"sparse-bitmask"
naive_quantized
=
"naive-quantized"
float_quantized
=
"float-quantized"
float_quantized
=
"float-quantized"
int_quantized
=
"int-quantized"
int_quantized
=
"int-quantized"
pack_quantized
=
"pack-quantized"
pack_quantized
=
"pack-quantized"
...
@@ -76,6 +77,15 @@ class QuantizationArgs(BaseModel):
...
@@ -76,6 +77,15 @@ class QuantizationArgs(BaseModel):
)
)
def is_activation_quantization_format(format: str) -> bool:
    """Return True when *format* names an activation-quantization format.

    The check is a plain membership test against the string values of the
    ``naive_quantized``, ``int_quantized`` and ``float_quantized`` members
    of ``CompressionFormat``; any other value yields False.
    """
    return format in (
        CompressionFormat.naive_quantized.value,
        CompressionFormat.int_quantized.value,
        CompressionFormat.float_quantized.value,
    )
def
find_first_name_or_class_match
(
def
find_first_name_or_class_match
(
name
:
str
,
name
:
str
,
module
:
Module
,
module
:
Module
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment