Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
abfe705a
Unverified
Commit
abfe705a
authored
Jul 07, 2024
by
Robert Shaw
Committed by
GitHub
Jul 07, 2024
Browse files
[ Misc ] Support Fp8 via `llm-compressor` (#6110)
Co-authored-by:
Robert Shaw
<
rshaw@neuralmagic
>
parent
333306a2
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
603 additions
and
263 deletions
+603
-263
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
...figs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
+11
-0
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml
...lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml
+1
-1
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
...igs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
+11
-0
.buildkite/lm-eval-harness/configs/models-small.txt
.buildkite/lm-eval-harness/configs/models-small.txt
+2
-0
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
+1
-1
.buildkite/lm-eval-harness/test_lm_eval_correctness.py
.buildkite/lm-eval-harness/test_lm_eval_correctness.py
+2
-1
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+27
-5
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+50
-8
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+19
-8
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+87
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+85
-0
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+1
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+41
-228
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+5
-9
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+97
-2
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+163
-0
No files found.
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
0 → 100644
View file @
abfe705a
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 250 -f 5 -t 1
model_name
:
"
nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.752
-
name
:
"
exact_match,flexible-extract"
value
:
0.752
limit
:
250
num_fewshot
:
5
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml
View file @
abfe705a
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-
hf
-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-
vllm
-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
model_name
:
"
neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
tasks
:
-
name
:
"
gsm8k"
...
...
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
0 → 100644
View file @
abfe705a
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
model_name
:
"
nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.728
-
name
:
"
exact_match,flexible-extract"
value
:
0.728
limit
:
250
num_fewshot
:
5
.buildkite/lm-eval-harness/configs/models-small.txt
View file @
abfe705a
Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
View file @
abfe705a
...
...
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done
lm_eval
--model
vllm
\
--model_args
pretrained
=
$MODEL
,tensor_parallel_size
=
$TP_SIZE
\
--model_args
pretrained
=
$MODEL
,tensor_parallel_size
=
$TP_SIZE
,add_bos_token
=
true
\
--tasks
gsm8k
--num_fewshot
$FEWSHOT
--limit
$LIMIT
\
--batch_size
$BATCH_SIZE
.buildkite/lm-eval-harness/test_lm_eval_correctness.py
View file @
abfe705a
...
...
@@ -24,7 +24,8 @@ TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1)
def
launch_lm_eval
(
eval_config
):
model_args
=
f
"pretrained=
{
eval_config
[
'model_name'
]
}
,"
\
f
"tensor_parallel_size=
{
TP_SIZE
}
"
f
"tensor_parallel_size=
{
TP_SIZE
}
,"
\
f
"add_bos_token=true"
results
=
lm_eval
.
simple_evaluate
(
model
=
"vllm"
,
...
...
tests/quantization/test_compressed_tensors.py
View file @
abfe705a
...
...
@@ -9,7 +9,8 @@ import torch
from
vllm
import
SamplingParams
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsLinearMethod
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8
,
CompressedTensorsWNA16
)
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationType
)
...
...
@@ -37,12 +38,11 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
CompressedTensorsLinearMethod
)
assert
isinstance
(
down_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8
Int8
)
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
scheme
.
is_static_input_scheme
expected_type
=
(
torch
.
int8
if
quant_type
==
QuantizationType
.
INT
else
torch
.
float8_e4m3fn
)
expected_type
=
torch
.
int8
assert
qkv_proj
.
weight
.
dtype
is
expected_type
assert
o_proj
.
weight
.
dtype
is
expected_type
...
...
@@ -79,7 +79,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8
Int8
)
assert
not
qkv_proj
.
scheme
.
is_static_input_scheme
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
weight
.
dtype
is
torch
.
int8
...
...
@@ -123,3 +123,25 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
sampling_params
=
SamplingParams
()
output
=
llm
.
generate
(
"Hello world!"
,
sampling_params
=
sampling_params
)
assert
output
def
test_compressed_tensors_fp8
(
vllm_runner
):
model_path
=
"nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
with
vllm_runner
(
model_path
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Fp8
)
assert
qkv_proj
.
weight
.
dtype
is
torch
.
float8_e4m3fn
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
# should be scalars after processing
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
sampling_params
=
SamplingParams
()
output
=
llm
.
generate
(
"Hello world!"
,
sampling_params
=
sampling_params
)
assert
output
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
abfe705a
...
...
@@ -9,10 +9,11 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensorsScheme
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8
,
CompressedTensorsWNA16
)
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
find_first_name_or_class_match
)
QuantizationType
,
find_first_name_or_class_match
)
from
vllm.platforms
import
current_platform
...
...
@@ -117,6 +118,40 @@ class CompressedTensorsConfig(QuantizationConfig):
return
is_8_bits
and
is_token
and
is_symmetric
and
is_dynamic
def
_is_fp8_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
# Confirm weights and activations quantized.
if
weight_quant
is
None
or
input_quant
is
None
:
return
False
# Confirm we have floating points.
if
not
(
weight_quant
.
type
==
QuantizationType
.
FLOAT
and
input_quant
.
type
==
QuantizationType
.
FLOAT
):
return
False
# Confirm weight scheme is supported.
is_symmetric_weight
=
weight_quant
.
symmetric
is_static_weight
=
not
weight_quant
.
dynamic
is_per_tensor_weight
=
(
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
if
not
(
is_symmetric_weight
and
is_static_weight
and
is_per_tensor_weight
):
return
False
# Dynamic quantization is always supported if weights supported.
if
input_quant
.
dynamic
:
return
True
# Confirm activation scheme is supported.
is_symmetric_activation
=
input_quant
.
symmetric
is_per_tensor_activation
=
(
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
if
not
(
is_symmetric_activation
and
is_per_tensor_activation
):
return
False
# All conditions satisfied.
return
True
def
_is_wNa16_group_channel
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant_none
=
input_quant
is
None
...
...
@@ -147,13 +182,20 @@ class CompressedTensorsConfig(QuantizationConfig):
strategy
=
weight_quant
.
strategy
,
group_size
=
weight_quant
.
group_size
)
if
self
.
quant_format
==
CompressionFormat
.
int_quantized
.
value
:
if
(
self
.
quant_format
==
CompressionFormat
.
int_quantized
.
value
or
self
.
quant_format
==
CompressionFormat
.
float_quantized
.
value
):
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Fp8
(
input_dynamic
=
input_quant
.
dynamic
)
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8
(
strategy
=
weight_quant
.
strategy
,
return
CompressedTensorsW8A8Int8
(
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
True
)
if
self
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8
(
strategy
=
weight_quant
.
strategy
,
return
CompressedTensorsW8A8Int8
(
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
False
)
raise
NotImplementedError
(
...
...
@@ -187,7 +229,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
self
.
quantization_config
=
quantization_config
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
return
layer
.
scheme
.
process_weights_after_loading
(
layer
)
layer
.
scheme
.
process_weights_after_loading
(
layer
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
abfe705a
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
# noqa: F401
from
.compressed_tensors_unquantized
import
(
# noqa: F401
CompressedTensorsUnquantized
)
from
.compressed_tensors_w4a16_24
import
(
# noqa: F401
W4A16SPARSE24_SUPPORTED_BITS
,
CompressedTensorsW4A16Sparse24
)
from
.compressed_tensors_w8a8
import
CompressedTensorsW8A8
# noqa: F401
from
.compressed_tensors_wNa16
import
WNA16_SUPPORTED_BITS
# noqa: F401
from
.compressed_tensors_wNa16
import
CompressedTensorsWNA16
# noqa: F401
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
from
.compressed_tensors_unquantized
import
CompressedTensorsUnquantized
from
.compressed_tensors_w4a16_24
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
CompressedTensorsW4A16Sparse24
)
from
.compressed_tensors_w8a8_fp8
import
CompressedTensorsW8A8Fp8
from
.compressed_tensors_w8a8_int8
import
CompressedTensorsW8A8Int8
from
.compressed_tensors_wNa16
import
(
WNA16_SUPPORTED_BITS
,
CompressedTensorsWNA16
)
__all__
=
[
"CompressedTensorsScheme"
,
"CompressedTensorsUnquantized"
,
"CompressedTensorsWNA16"
,
"CompressedTensorsW4A16Sparse24"
,
"CompressedTensorsW8A8Int8"
,
"CompressedTensorsW8A8Fp8"
,
"WNA16_SUPPORTED_BITS"
,
"W4A16SPARSE24_SUPPORTED_BITS"
,
]
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
0 → 100644
View file @
abfe705a
from
typing
import
Callable
,
List
,
Optional
import
torch
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
requantize_with_max_scale
)
from
vllm.model_executor.utils
import
set_weight_attrs
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
class
CompressedTensorsW8A8Fp8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
input_dynamic
:
bool
):
self
.
input_dynamic
=
input_dynamic
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), we requantize with a single scale.
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
# Dequant -> Quant with max scale.
max_w_scale
,
weight
=
requantize_with_max_scale
(
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
)
# Update layer with new values.
layer
.
weight
=
torch
.
nn
.
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
max_w_scale
,
requires_grad
=
False
)
if
self
.
input_dynamic
:
layer
.
input_scale
=
None
else
:
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
del
params_dtype
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
logical_widths
=
output_partition_sizes
# WEIGHT
weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
requires_grad
=
False
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
"weight_loader"
:
weight_loader
,
})
# WEIGHT SCALE
weight_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE
if
not
self
.
input_dynamic
:
input_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
→
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8
_int8
.py
View file @
abfe705a
from
typing
import
Callable
,
List
,
Tuple
,
Union
from
typing
import
Callable
,
List
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_int8_linear
,
convert_to_channelwise
,
create_per_channel_scale_param
,
create_per_tensor_scale_param
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
CompressedTensorsW8A8
(
CompressedTensorsScheme
):
class
CompressedTensorsW8A8
Int8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
):
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
# Cutlass kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), we convert to the per-channel case.
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
if
(
self
.
strategy
==
QuantizationStrategy
.
TENSOR
and
len
(
self
.
logical_widths
)
>
1
):
# WEIGHT
# Cutlass kernels need transposed weight.
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
# Load the N per-tensor scales into the channelwise buffer.
weight_scale_channel
=
torch
.
empty
(
(
sum
(
self
.
logical_widths
),
1
),
dtype
=
torch
.
float32
,
device
=
layer
.
weight_scale
.
device
)
start
=
0
for
idx
,
logical_width
in
enumerate
(
self
.
logical_widths
):
end
=
start
+
logical_width
weight_scale_channel
[
start
:
end
,
:]
=
layer
.
weight_scale
[
idx
]
start
=
end
# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module
=
len
(
self
.
logical_widths
)
>
1
if
is_fused_module
and
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
ws_channelwise
=
convert_to_channelwise
(
layer
.
weight_scale
,
self
.
logical_widths
)
layer
.
weight_scale
=
Parameter
(
ws_channelwise
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale_channel
,
# INPUT SCALE
if
self
.
is_static_input_scheme
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
# transpose weights for cutlass.
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
else
:
layer
.
input_scale
=
None
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
...
...
@@ -49,27 +49,6 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
**
kwargs
):
self
.
logical_widths
=
output_partition_sizes
# WEIGHT SCALE
shape
:
Union
[
Tuple
[
int
],
Tuple
[
int
,
int
]]
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
shape
=
(
sum
(
self
.
logical_widths
),
1
)
else
:
shape
=
(
len
(
self
.
logical_widths
),
)
weight_scale
=
Parameter
(
torch
.
empty
(
*
shape
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
set_weight_attrs
(
weight_scale
,
{
"weight_loader"
:
weight_loader
,
"output_dim"
:
0
,
})
else
:
set_weight_attrs
(
weight_scale
,
{
"weight_loader"
:
weight_loader
,
"needs_scalar_to_array"
:
True
,
})
# WEIGHT
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
...
...
@@ -82,28 +61,25 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
"weight_loader"
:
weight_loader
,
})
# WEIGHT SCALE
layer_kwargs
=
{
"weight_loader"
:
weight_loader
}
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
scale
=
create_per_channel_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
else
:
assert
self
.
strategy
==
QuantizationStrategy
.
TENSOR
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
layer
.
register_parameter
(
"weight_scale"
,
scale
)
# INPUT SCALE
# Static quantization: load from disk.
if
self
.
is_static_input_scheme
:
input_scale
=
Parameter
(
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
set_weight_attrs
(
input_scale
,
{
"weight_loader"
:
weight_loader
,
"ignore_warning"
:
True
,
})
# Dynamic quantization: set to None.
else
:
layer
.
input_scale
=
None
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
layer
.
register_parameter
(
"input_scale"
,
scale
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q
,
x_scale
=
ops
.
scaled_int8_quant
(
x
,
layer
.
input_scale
)
return
ops
.
cutlass_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
)
return
apply_int8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
)
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
abfe705a
...
...
@@ -9,6 +9,7 @@ from torch.nn import Module
class
CompressionFormat
(
Enum
):
dense
=
"dense"
sparse_bitmask
=
"sparse-bitmask"
float_quantized
=
"float-quantized"
int_quantized
=
"int-quantized"
pack_quantized
=
"pack-quantized"
marlin_24
=
"marlin-24"
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
abfe705a
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn
import
Module
...
...
@@ -11,11 +11,11 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQMarlinState
,
marlin_permute_scales
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
pack_fp8_to_int32
)
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
apply_fp8_linear
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
per_tensor_dequantize
,
requantize_with_max_scale
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
...
...
@@ -25,13 +25,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger
=
init_logger
(
__name__
)
def
cutlass_fp8_supported
()
->
bool
:
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
return
ops
.
cutlass_scaled_mm_supports_fp8
(
capability
)
class
Fp8Config
(
QuantizationConfig
):
"""Config class for FP8."""
...
...
@@ -117,23 +110,6 @@ class Fp8LinearMethod(LinearMethodBase):
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
self
.
use_marlin
=
capability
<
89
def
_create_scale_param
(
self
,
scale_name
:
str
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
**
extra_weight_attrs
,
)
->
None
:
scale
=
Parameter
(
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
requires_grad
=
False
)
scale
[:]
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
min
layer
.
register_parameter
(
scale_name
,
scale
)
set_weight_attrs
(
scale
,
{
**
extra_weight_attrs
,
"needs_scalar_to_array"
:
True
,
})
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -147,7 +123,6 @@ class Fp8LinearMethod(LinearMethodBase):
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
process_after_load
=
True
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
...
...
@@ -173,144 +148,50 @@ class Fp8LinearMethod(LinearMethodBase):
# Otherwise, wait until process_weights_after_loading.
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
self
.
_create_scale_param
(
scale_name
=
"weight_scale"
,
layer
=
layer
,
output_partition_sizes
=
output_partition_sizes
,
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
extra_weight_attrs
)
layer
.
register_parameter
(
"weight_scale"
,
scale
)
# INPUT ACTIVATION SCALE
if
self
.
quant_config
.
activation_scheme
==
"static"
:
self
.
_create_scale_param
(
scale_name
=
"input_scale"
,
layer
=
layer
,
output_partition_sizes
=
output_partition_sizes
,
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
extra_weight_attrs
)
# For GPUs without FP8 hardware support, we use Marlin for fast
# fused dequantization
if
self
.
use_marlin
:
layer
.
marlin_state
=
GPTQMarlinState
.
REPACK
def
prepare_layer_for_marlin
(
self
,
layer
:
Module
)
->
None
:
print_warning_once
(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
part_size_n
=
layer
.
output_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
assert
layer
.
marlin_state
==
GPTQMarlinState
.
REPACK
layer
.
marlin_state
=
GPTQMarlinState
.
READY
device
=
layer
.
weight
.
device
# WEIGHTS
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight
=
pack_fp8_to_int32
(
layer
.
weight
)
# Repack weights to marlin format
marlin_qweight
=
ops
.
gptq_marlin_repack
(
b_q_weight
=
packed_gptq_qweight
,
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
),
size_k
=
part_size_k
,
size_n
=
part_size_n
,
num_bits
=
8
,
)
layer
.
weight
=
Parameter
(
marlin_qweight
,
requires_grad
=
False
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales
=
layer
.
weight_scale
.
repeat
(
1
,
part_size_n
).
to
(
layer
.
orig_dtype
).
to
(
device
)
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
part_size_k
,
size_n
=
part_size_n
,
group_size
=-
1
,
num_bits
=
8
,
)
layer
.
weight_scale
=
Parameter
(
marlin_scales
,
requires_grad
=
False
)
# Allocate marlin workspace
max_workspace_size
=
(
part_size_n
//
GPTQ_MARLIN_MIN_THREAD_N
)
*
GPTQ_MARLIN_MAX_PARALLEL
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
layer
.
workspace
=
workspace
layer
.
register_parameter
(
"input_scale"
,
scale
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
(
not
hasattr
(
layer
,
"process_after_load"
)
or
not
layer
.
process_after_load
):
return
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
# If checkpoint not serialized fp8, quantize the weights.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
# Update the layer with the new values.
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
logical_widths
=
None
layer
.
input_scale
=
None
if
self
.
use_marlin
:
self
.
prepare_layer_for_marlin
(
layer
)
return
# If checkpoint is fp8, requantize the separately quantized logical
# weights into a single fp8 weight with a single weight scale.
else
:
# WEIGHT_SCALE / WEIGHT
# Loop over logical weights, requantizing with single scale.
max_w_scale
=
layer
.
weight_scale
.
max
()
# QKV / MLP is fused in the on disk checkpoint if any of the
# weight scales are still set to the default since we initialize
# N weight scales for N shards but we only load 1 weight scale
# from disk in this case. As a result, we skip dequant -> requant
# since we already have quantized QKV together.
# Sample Model with fused checkpoint:
# * nm-testing/Phi-3-mini-128k-instruct-FP8
unfused_module_in_checkpoint
=
(
layer
.
weight_scale
[
-
1
]
>
torch
.
finfo
(
torch
.
float8_e4m3fn
).
min
)
if
unfused_module_in_checkpoint
:
start
=
0
for
idx
,
logical_width
in
enumerate
(
layer
.
logical_widths
):
end
=
start
+
logical_width
weight_dq
=
per_tensor_dequantize
(
layer
.
weight
[
start
:
end
,
:],
layer
.
weight_scale
[
idx
])
layer
.
weight
[
start
:
end
,
:]
=
per_tensor_quantize
(
weight_dq
,
layer
.
weight_scale
.
max
())
start
=
end
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
# Dequant -> Quant with max scale.
max_w_scale
,
weight
=
requantize_with_max_scale
(
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
)
# WEIGHT
# Transpose weight for passing to torch._scaled_mm
weight
=
layer
.
weight
# Update layer with new values.
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
# INPUT ACTIVATION SCALE
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Static: set to max of the input_scales (since they are equal).
if
self
.
quant_config
.
activation_scheme
==
"dynamic"
:
layer
.
input_scale
=
None
elif
self
.
quant_config
.
activation_scheme
==
"static"
:
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
if
self
.
quant_config
.
activation_scheme
==
"static"
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
else
:
raise
ValueError
(
f
"Unknown scheme
{
self
.
quant_config
.
activation_scheme
}
"
)
layer
.
input_scale
=
None
if
self
.
use_marlin
:
self
.
prepare_layer_for_marlin
(
layer
)
prepare_fp8_layer_for_marlin
(
layer
)
# Activations not quantized for marlin.
del
layer
.
input_scale
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -318,65 +199,22 @@ class Fp8LinearMethod(LinearMethodBase):
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
use_marlin
:
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out_shape
=
x
.
shape
[:
-
1
]
+
(
layer
.
output_size_per_partition
,
)
output
=
ops
.
fp8_marlin_gemm
(
a
=
reshaped_x
,
b_q_weight
=
layer
.
weight
,
b_scales
=
layer
.
weight_scale
,
return
apply_fp8_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
workspace
=
layer
.
workspace
,
num_bits
=
8
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
bias
=
bias
)
return
output
.
reshape
(
out_shape
)
else
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x
# If static, layer.input_scale is scalar and x_scale is input_scale
if
bias
is
None
and
self
.
cutlass_fp8_supported
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
input_scale
)
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
layer
.
weight
,
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
)
else
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
input_scale
,
batch_dim_padding
=
17
)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
layer
.
weight
,
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
)
return
torch
.
narrow
(
output
,
0
,
0
,
x
.
shape
[
0
])
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
)
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
...
...
@@ -399,8 +237,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
layer
.
process_after_load
=
True
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
params_dtype
=
torch
.
float8_e4m3fn
...
...
@@ -465,9 +301,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
a2_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
(
not
hasattr
(
layer
,
"process_after_load"
)
or
not
layer
.
process_after_load
):
return
# If checkpoint is fp16, quantize in place.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
...
...
@@ -531,7 +364,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
shard_size
,
:],
layer
.
w13_scale
[
expert_id
][
shard_id
])
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:]
=
per_tensor
_quant
ize
(
start
:
start
+
shard_size
,
:]
,
_
=
ops
.
scaled_fp8
_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
start
+=
shard_size
...
...
@@ -596,23 +429,3 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint."
)
del
layer
.
kv_scale
def
per_tensor_quantize
(
tensor
:
torch
.
Tensor
,
inv_scale
:
Union
[
float
,
torch
.
Tensor
])
->
torch
.
Tensor
:
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
qweight
=
(
tensor
/
inv_scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
return
qweight
.
to
(
torch
.
float8_e4m3fn
)
def per_tensor_dequantize(
        tensor: torch.Tensor, inv_scale: Union[float,
                                               torch.Tensor]) -> torch.Tensor:
    """Dequantize ``tensor`` by casting to float16 and applying the scale.

    Args:
        tensor: Quantized values.
        inv_scale: Multiplier applied after the cast to ``torch.float16``.

    Returns:
        ``tensor.to(torch.float16) * inv_scale``.
    """
    return tensor.to(torch.float16) * inv_scale
def all_close_1d(x: torch.Tensor) -> bool:
    """Return True when every element of the 1-D tensor ``x`` is close to
    its first element, as judged by ``torch.allclose`` with default
    tolerances.

    An empty tensor yields True because there is nothing to compare.
    """
    assert x.ndim == 1
    # x[0] is evaluated lazily per element, so an empty tensor never indexes.
    return all(torch.allclose(x[0], element) for element in x)
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
abfe705a
...
...
@@ -11,20 +11,16 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_K
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
,
GPTQ_MARLIN_SUPPORTED_SYM
,
GPTQ_MARLIN_TILE
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
# Permutations for Marlin scale shuffling
def
get_scale_perms
(
num_bits
:
int
):
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
abfe705a
"""This file is used for /tests and /benchmarks"""
import
random
from
typing
import
Optional
import
numpy
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.format_24
import
(
mask_creator
,
sparse_semi_structured_from_dense_cutlass
)
from
vllm.model_executor.layers.quantization.utils.marlin_24_perms
import
(
...
...
@@ -13,8 +15,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
sort_weights
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
MARLIN_TILE
=
16
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
def
is_marlin_supported
():
...
...
@@ -22,7 +32,92 @@ def is_marlin_supported():
return
capability
[
0
]
>=
8
def
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
,
tile
=
MARLIN_TILE
):
def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Run a linear layer through the FP8 Marlin GEMM kernel.

    For GPUs that lack FP8 hardware support, the Marlin kernel is leveraged
    for fast weight-only FP8 quantization.

    Args:
        input: Activations; all leading dimensions are flattened into one
            before the GEMM and restored on the result.
        weight: Weight passed to the kernel as ``b_q_weight``.
        weight_scale: Scales passed to the kernel as ``b_scales``.
        workspace: Workspace buffer forwarded to the kernel.
        size_n: Size of the last dimension of the output.
        size_k: Forwarded to the kernel as ``size_k``.
        bias: Optional bias, added in place to the GEMM result.

    Returns:
        A tensor of shape ``input.shape[:-1] + (size_n,)``.
    """
    # Collapse every leading dimension so the kernel sees a 2-D matrix.
    flat_input = input.reshape(-1, input.shape[-1])

    result = ops.fp8_marlin_gemm(
        a=flat_input,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=flat_input.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        result.add_(bias)  # In-place add

    restored_shape = input.shape[:-1] + (size_n, )
    return result.reshape(restored_shape)
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Convert an FP8 layer in place so it can run on the Marlin kernel.

    Emits a one-time warning, then replaces ``layer.weight`` with the
    Marlin-repacked weight, replaces ``layer.weight_scale`` with permuted
    channelwise scales, and attaches a zeroed ``layer.workspace`` buffer.

    Args:
        layer: Module providing ``weight``, ``weight_scale``, ``orig_dtype``,
            ``output_size_per_partition`` and ``input_size_per_partition``.
    """
    print_warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    size_n = layer.output_size_per_partition
    size_k = layer.input_size_per_partition
    device = layer.weight.device
    num_bits = 8

    # WEIGHTS
    # First repack to gptq format (packed int32 elements), then to the
    # marlin layout. An empty permutation tensor is handed to the repack op.
    gptq_packed = pack_fp8_to_int32(layer.weight)
    empty_perm = torch.empty(0, dtype=torch.int, device=device)
    repacked_weight = ops.gptq_marlin_repack(
        b_q_weight=gptq_packed,
        perm=empty_perm,
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(repacked_weight, requires_grad=False)

    # WEIGHT SCALES
    # Currently Marlin doesn't support per-tensor scales, so the scale is
    # expanded to channelwise before being permuted.
    channelwise_scales = layer.weight_scale.repeat(1, size_n)
    channelwise_scales = channelwise_scales.to(layer.orig_dtype).to(device)
    permuted_scales = marlin_permute_scales(
        s=channelwise_scales,
        size_k=size_k,
        size_n=size_n,
        group_size=-1,
        scale_perm=marlin_scale_perm[num_bits],
        scale_perm_single=marlin_scale_perm_single[num_bits])
    layer.weight_scale = torch.nn.Parameter(permuted_scales,
                                            requires_grad=False)

    # Allocate marlin workspace
    workspace_len = (size_n //
                     GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
    layer.workspace = torch.zeros(workspace_len,
                                  dtype=torch.int,
                                  device=device,
                                  requires_grad=False)
def
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
,
tile
=
GPTQ_MARLIN_TILE
):
assert
q_w
.
shape
==
(
size_k
,
size_n
)
assert
size_k
%
tile
==
0
,
f
"size_k =
{
size_k
}
, tile =
{
tile
}
"
assert
size_n
%
tile
==
0
,
f
"size_k =
{
size_n
}
, tile =
{
tile
}
"
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
0 → 100644
View file @
abfe705a
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
def cutlass_fp8_supported() -> bool:
    """Report whether the CUTLASS scaled-mm kernel supports FP8 here.

    The current platform's device capability is folded into one integer
    (first component * 10 + second component) and handed to
    ``ops.cutlass_scaled_mm_supports_fp8``.
    """
    # Presumably (major, minor) compute capability -- confirm against
    # current_platform.get_device_capability().
    device_capability = current_platform.get_device_capability()
    encoded_capability = device_capability[0] * 10 + device_capability[1]
    return ops.cutlass_scaled_mm_supports_fp8(encoded_capability)
def per_tensor_dequantize(
        tensor: torch.Tensor, inv_scale: Union[float,
                                               torch.Tensor]) -> torch.Tensor:
    """Dequantize ``tensor`` with a single shared scale.

    The tensor is first cast to ``torch.float16`` and the result is then
    multiplied by ``inv_scale``.

    Args:
        tensor: Quantized values.
        inv_scale: Multiplier applied to the float16 copy of ``tensor``.

    Returns:
        The dequantized tensor.
    """
    as_half = tensor.to(torch.float16)
    return as_half * inv_scale
def all_close_1d(x: torch.Tensor) -> bool:
    """Check that all entries of a 1-D tensor are close to the first one.

    Each element is compared against ``x[0]`` with ``torch.allclose`` using
    its default tolerances. An empty tensor returns True.
    """
    assert x.ndim == 1
    # The generator only touches x[0] when there is at least one element.
    return all(torch.allclose(x[0], entry) for entry in x)
def create_per_tensor_scale_param(
    output_partition_sizes: List[int],
    **extra_weight_attrs,
) -> Parameter:
    """Create a per-tensor scale parameter with one entry per partition.

    The parameter is a 1-D float32 tensor of length
    ``len(output_partition_sizes)``, does not require grad, and has every
    entry initialised to the float32 minimum.

    Args:
        output_partition_sizes: One entry per output partition; only the
            count is used here.
        **extra_weight_attrs: Extra attributes set on the parameter; they
            take precedence over the ``needs_scalar_to_array`` default.

    Returns:
        The newly created scale parameter.
    """
    num_partitions = len(output_partition_sizes)
    placeholder = torch.finfo(torch.float32).min
    scale = Parameter(torch.full((num_partitions, ),
                                 placeholder,
                                 dtype=torch.float32),
                      requires_grad=False)

    attrs = {"needs_scalar_to_array": True}
    attrs.update(extra_weight_attrs)
    set_weight_attrs(scale, attrs)
    return scale
def create_per_channel_scale_param(output_partition_sizes: List[int],
                                   **extra_weight_attrs) -> Parameter:
    """Create a per-channel scale parameter covering all partitions.

    The parameter is a float32 tensor of shape
    ``(sum(output_partition_sizes), 1)``, does not require grad, and has
    every entry initialised to the float32 minimum.

    Args:
        output_partition_sizes: Widths of the output partitions.
        **extra_weight_attrs: Extra attributes set on the parameter; they
            take precedence over the ``output_dim`` default of 0.

    Returns:
        The newly created scale parameter.
    """
    total_channels = sum(output_partition_sizes)
    placeholder = torch.finfo(torch.float32).min
    scale = Parameter(torch.full((total_channels, 1),
                                 placeholder,
                                 dtype=torch.float32),
                      requires_grad=False)

    attrs = {"output_dim": 0}
    attrs.update(extra_weight_attrs)
    set_weight_attrs(scale, attrs)
    return scale
def convert_to_channelwise(weight_scale: torch.Tensor,
                           logical_widths: List[int]) -> torch.Tensor:
    """Expand per-shard scales into one scale per output channel.

    Args:
        weight_scale: Scales indexed by logical shard; entry ``idx`` is the
            scale of the ``idx``-th logical matrix.
        logical_widths: Number of output channels of each logical matrix.

    Returns:
        A single float32 tensor of shape ``(sum(logical_widths), 1)`` on the
        device of ``weight_scale``, where the rows belonging to shard
        ``idx`` all hold ``weight_scale[idx]``.
    """
    # Create channelwise buffer
    weight_scale_channel = torch.empty((sum(logical_widths), 1),
                                       dtype=torch.float32,
                                       device=weight_scale.device)

    # Expand each scale to match the size of each logical matrix.
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_scale_channel[start:end, :] = weight_scale[idx]
        start = end

    return weight_scale_channel
def requantize_with_max_scale(
        weight: torch.Tensor, weight_scale: torch.Tensor,
        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Requantize a multi-shard weight so that all shards share one scale.

    Args:
        weight: Quantized weight whose rows are the logical shards laid out
            back to back; updated in place when requantization happens.
        weight_scale: One scale per logical shard.
        logical_widths: Number of rows belonging to each logical shard.

    Returns:
        ``(max_w_scale, weight)``: the maximum of ``weight_scale`` and the
        (possibly requantized) weight.
    """
    # Max scale to be used for requantization.
    max_w_scale = weight_scale.max()

    # When QKV / MLP is already fused in the on-disk checkpoint, only one
    # weight scale is loaded although N scales were initialized for the N
    # shards, so some scales keep their default value. The weight is then
    # already quantized with the single scale and requantization is skipped.
    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
    fp8_min = torch.finfo(torch.float8_e4m3fn).min
    unfused_module_in_checkpoint = weight_scale[-1] > fp8_min
    if not unfused_module_in_checkpoint:
        return max_w_scale, weight

    # Unfused checkpoint: each shard is dequantized with its own scale and
    # quantized again with the shared max scale.
    offset = 0
    for shard_idx, width in enumerate(logical_widths):
        rows = slice(offset, offset + width)
        dequantized = per_tensor_dequantize(weight[rows, :],
                                            weight_scale[shard_idx])
        requantized, _ = ops.scaled_fp8_quant(dequantized, max_w_scale)
        weight[rows, :] = requantized
        offset += width

    return max_w_scale, weight
def apply_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    cutlass_fp8_supported: bool = True,
) -> torch.Tensor:
    """Quantize the activations to FP8 and run a fused GEMM + dequantize.

    ``ops.scaled_fp8_quant`` supports both dynamic and static quantization:
    if dynamic, ``input_scale`` is None and the scale is computed from
    ``input``; if static, ``input_scale`` is a scalar and is used as is.

    Args:
        input: Activations to quantize and multiply.
        weight: FP8 weight handed to the GEMM.
        weight_scale: Scale of ``weight`` (``scale_b`` of the GEMM).
        input_scale: Static activation scale, or None for dynamic.
        bias: Optional bias; when given, the ``torch._scaled_mm`` path is
            always taken and the bias is passed to it.
        cutlass_fp8_supported: Whether the CUTLASS kernel may be used.

    Returns:
        The GEMM output in ``input.dtype``, narrowed along dim 0 to the
        first ``input.shape[0]`` rows.
    """
    if bias is not None or not cutlass_fp8_supported:
        quantized_input, activation_scale = ops.scaled_fp8_quant(
            input, input_scale, batch_dim_padding=17)

        # Fused GEMM_DQ -- the input was padded above because
        # torch._scaled_mm is more performant for matrices with
        # batch dimension > 16. Note that this could change
        # in the future.
        output, _ = torch._scaled_mm(quantized_input,
                                     weight,
                                     out_dtype=input.dtype,
                                     scale_a=activation_scale,
                                     scale_b=weight_scale,
                                     bias=bias)
    else:
        quantized_input, activation_scale = ops.scaled_fp8_quant(
            input, input_scale)

        # Fused GEMM_DQ
        output = ops.cutlass_scaled_mm(quantized_input,
                                       weight,
                                       out_dtype=input.dtype,
                                       scale_a=activation_scale,
                                       scale_b=weight_scale)

    # Keep only the first input.shape[0] rows (presumably dropping any
    # padding rows added by batch_dim_padding).
    num_rows = input.shape[0]
    return torch.narrow(output, 0, 0, num_rows)
def apply_int8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
):
    """Quantize the activations to int8 and run the CUTLASS scaled GEMM.

    ``ops.scaled_int8_quant`` supports both dynamic and static quantization:
    * dynamic: ``input_scale`` is None and the scale is computed from
      ``input``.
    * static: ``input_scale`` is a scalar and is used as the scale.

    Args:
        input: Activations to quantize and multiply.
        weight: Weight handed to the GEMM.
        weight_scale: Scale of ``weight`` (``scale_b`` of the GEMM).
        input_scale: Static activation scale, or None for dynamic.
        bias: Must be None; a bias is not supported yet.

    Returns:
        The result of ``ops.cutlass_scaled_mm`` in ``input.dtype``.

    Raises:
        NotImplementedError: If ``bias`` is not None.
    """
    if bias is not None:
        raise NotImplementedError("W8A8 with int8 does not yet support bias.")

    quantized_input, activation_scale = ops.scaled_int8_quant(
        input, input_scale)

    gemm_output = ops.cutlass_scaled_mm(quantized_input,
                                        weight,
                                        scale_a=activation_scale,
                                        scale_b=weight_scale,
                                        out_dtype=input.dtype)
    return gemm_output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment