Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
172d1cd2
Unverified
Commit
172d1cd2
authored
Sep 27, 2024
by
Luka Govedič
Committed by
GitHub
Sep 27, 2024
Browse files
[Kernel] AQ AZP 4/4: Integrate asymmetric quantization to linear method (#7271)
parent
a9b15c60
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
124 additions
and
21 deletions
+124
-21
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
...eta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
+11
-0
.buildkite/lm-eval-harness/configs/models-small.txt
.buildkite/lm-eval-harness/configs/models-small.txt
+1
-0
.buildkite/lm-eval-harness/test_lm_eval_correctness.py
.buildkite/lm-eval-harness/test_lm_eval_correctness.py
+6
-1
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+27
-9
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+10
-6
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+52
-3
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+17
-2
No files found.
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
0 → 100644
View file @
172d1cd2
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
model_name
:
"
nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.764
-
name
:
"
exact_match,flexible-extract"
value
:
0.764
limit
:
250
num_fewshot
:
5
.buildkite/lm-eval-harness/configs/models-small.txt
View file @
172d1cd2
Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Minitron-4B-Base-FP8.yaml
...
...
.buildkite/lm-eval-harness/test_lm_eval_correctness.py
View file @
172d1cd2
...
...
@@ -49,10 +49,15 @@ def test_lm_eval_correctness():
results
=
launch_lm_eval
(
eval_config
)
# Confirm scores match ground truth.
success
=
True
for
task
in
eval_config
[
"tasks"
]:
for
metric
in
task
[
"metrics"
]:
ground_truth
=
metric
[
"value"
]
measured_value
=
results
[
"results"
][
task
[
"name"
]][
metric
[
"name"
]]
print
(
f
'
{
task
[
"name"
]
}
|
{
metric
[
"name"
]
}
: '
f
'ground_truth=
{
ground_truth
}
| measured=
{
measured_value
}
'
)
assert
numpy
.
isclose
(
ground_truth
,
measured_value
,
rtol
=
RTOL
)
success
=
success
and
numpy
.
isclose
(
ground_truth
,
measured_value
,
rtol
=
RTOL
)
# Assert at the end, print all scores even on failure for debugging.
assert
success
tests/quantization/test_compressed_tensors.py
View file @
172d1cd2
...
...
@@ -2,6 +2,7 @@
Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
from
typing
import
Optional
import
pytest
import
torch
...
...
@@ -14,14 +15,16 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationType
)
@
pytest
.
mark
.
parametrize
(
"model_args"
,
[
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
"tensor"
,
QuantizationType
.
INT
,
2560
),
@
pytest
.
mark
.
parametrize
(
"model_args"
,
[(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
"tensor"
,
QuantizationType
.
INT
,
2560
,
True
),
(
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor"
,
"channel"
,
QuantizationType
.
INT
,
2560
),
])
QuantizationType
.
INT
,
2560
,
True
),
(
"nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama"
,
"tensor"
,
QuantizationType
.
INT
,
2560
,
False
)])
def
test_compressed_tensors_w8a8_static_setup
(
vllm_runner
,
model_args
):
model_path
,
strategy
,
quant_type
,
shape_0
=
model_args
model_path
,
strategy
,
quant_type
,
shape_0
,
is_symmetric
=
model_args
with
vllm_runner
(
model_path
,
enforce_eager
=
True
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
...
...
@@ -31,6 +34,18 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
gate_up_proj
=
layer
.
mlp
.
gate_up_proj
down_proj
=
layer
.
mlp
.
down_proj
# assert zp for symmetric and asymmetric cases
def
zp_valid
(
zp
:
Optional
[
torch
.
Tensor
]):
if
is_symmetric
:
return
zp
is
None
return
zp
is
not
None
and
zp
.
dtype
is
torch
.
int32
assert
zp_valid
(
qkv_proj
.
input_zero_point
)
assert
zp_valid
(
o_proj
.
input_zero_point
)
assert
zp_valid
(
gate_up_proj
.
input_zero_point
)
assert
zp_valid
(
down_proj
.
input_zero_point
)
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
o_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
gate_up_proj
.
quant_method
,
...
...
@@ -69,9 +84,12 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
@
pytest
.
mark
.
parametrize
(
"model_args"
,
[
(
"nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
,
"tensor"
),
(
"nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym"
,
"tensor"
),
(
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
,
"channel"
),
(
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym"
,
"channel"
),
])
def
test_compressed_tensors_w8a8_dyna
n
mic_per_token
(
vllm_runner
,
model_args
):
def
test_compressed_tensors_w8a8_dynamic_per_token
(
vllm_runner
,
model_args
):
model_path
,
strategy
=
model_args
with
vllm_runner
(
model_path
,
dtype
=
torch
.
float16
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
172d1cd2
...
...
@@ -138,10 +138,11 @@ class CompressedTensorsConfig(QuantizationConfig):
or
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
)
is_tensor
=
(
weight_strategy
and
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
)
is_symmetric
=
weight_quant
.
symmetric
and
input_quant
.
symmetric
is_static
=
not
weight_quant
.
dynamic
and
not
input_quant
.
dynamic
return
is_8_bits
and
is_tensor
and
is_symmetric
and
is_static
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return
is_8_bits
and
is_tensor
and
weight_quant
.
symmetric
and
is_static
def
_is_dynamic_token_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
...
...
@@ -151,10 +152,11 @@ class CompressedTensorsConfig(QuantizationConfig):
or
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
)
is_token
=
(
weight_strategy
and
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
.
value
)
is_symmetric
=
weight_quant
.
symmetric
and
input_quant
.
symmetric
is_dynamic
=
not
weight_quant
.
dynamic
and
input_quant
.
dynamic
return
is_8_bits
and
is_token
and
is_symmetric
and
is_dynamic
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return
is_8_bits
and
is_token
and
weight_quant
.
symmetric
and
is_dynamic
def
_is_fp8_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
...
...
@@ -265,12 +267,14 @@ class CompressedTensorsConfig(QuantizationConfig):
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8
(
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
True
)
is_static_input_scheme
=
True
,
input_symmetric
=
input_quant
.
symmetric
)
if
self
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8
(
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
False
)
is_static_input_scheme
=
False
,
input_symmetric
=
input_quant
.
symmetric
)
raise
NotImplementedError
(
"No compressed-tensors compatible scheme was found."
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
172d1cd2
...
...
@@ -3,6 +3,7 @@ from typing import Callable, List, Optional
import
torch
from
torch.nn
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
...
...
@@ -14,12 +15,16 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
ModelWeightParameter
,
PerTensorScaleParameter
)
logger
=
init_logger
(
__name__
)
class
CompressedTensorsW8A8Int8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
,
input_symmetric
:
bool
):
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
input_symmetric
=
input_symmetric
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
...
...
@@ -46,10 +51,43 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
requires_grad
=
False
)
# INPUT SCALE
if
self
.
is_static_input_scheme
:
if
self
.
input_symmetric
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
layer
.
input_zero_point
=
None
else
:
# reconstruct the ranges
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
azps
=
layer
.
input_zero_point
.
to
(
dtype
=
torch
.
int32
)
range_max
=
(
layer
.
input_scale
*
(
int8_traits
.
max
-
azps
)).
max
()
range_min
=
(
layer
.
input_scale
*
(
int8_traits
.
min
-
azps
)).
min
()
scale
=
(
range_max
-
range_min
)
/
(
int8_traits
.
max
-
int8_traits
.
min
)
layer
.
input_scale
=
Parameter
(
scale
,
requires_grad
=
False
)
# AZP loaded as int8 but used as int32
azp
=
(
int8_traits
.
min
-
range_min
/
scale
).
to
(
dtype
=
torch
.
int32
)
layer
.
input_zero_point
=
Parameter
(
azp
,
requires_grad
=
False
)
else
:
layer
.
input_scale
=
None
layer
.
input_zero_point
=
None
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
if
not
self
.
input_symmetric
:
layer
.
azp_adj
=
layer
.
weight
.
sum
(
dim
=
0
,
keepdim
=
True
,
dtype
=
torch
.
int32
)
else
:
layer
.
azp_adj
=
None
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
...
...
@@ -90,6 +128,15 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
if
not
self
.
input_symmetric
:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
int8
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_zero_point"
,
input_zero_point
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
...
...
@@ -97,4 +144,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
input_zero_point
=
layer
.
input_zero_point
,
azp_adj
=
layer
.
azp_adj
,
bias
=
bias
)
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
172d1cd2
...
...
@@ -191,13 +191,28 @@ def apply_int8_linear(
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
input_zero_point
:
Optional
[
torch
.
Tensor
]
=
None
,
azp_adj
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q
,
x_scale
,
_
=
ops
.
scaled_int8_quant
(
input
,
input_scale
)
symmetric
=
azp_adj
is
None
x_q
,
x_scale
,
x_zp
=
ops
.
scaled_int8_quant
(
input
,
input_scale
,
input_zero_point
,
symmetric
=
symmetric
)
if
x_zp
is
not
None
:
return
ops
.
cutlass_scaled_mm_azp
(
x_q
,
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
azp_adj
=
azp_adj
,
azp
=
x_zp
,
bias
=
bias
)
return
ops
.
cutlass_scaled_mm
(
x_q
,
weight
,
scale_a
=
x_scale
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment