Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0fa14907
Unverified
Commit
0fa14907
authored
Aug 08, 2024
by
Siyuan Liu
Committed by
GitHub
Aug 08, 2024
Browse files
[TPU] Add Load-time W8A16 quantization for TPU Backend (#7005)
parent
5923532e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
135 additions
and
8 deletions
+135
-8
vllm/config.py
vllm/config.py
+6
-0
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-0
vllm/model_executor/layers/quantization/tpu_int8.py
vllm/model_executor/layers/quantization/tpu_int8.py
+118
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+9
-8
No files found.
vllm/config.py
View file @
0fa14907
...
...
@@ -244,6 +244,7 @@ class ModelConfig:
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
]
tpu_supported_quantization
=
[
"tpu_int8"
]
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
...
...
@@ -282,6 +283,11 @@ class ModelConfig:
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
if
is_tpu
(
)
and
self
.
quantization
not
in
tpu_supported_quantization
:
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in TPU Backend."
)
if
self
.
quantization
not
in
optimized_quantization_methods
:
logger
.
warning
(
"%s quantization is not fully "
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
0fa14907
...
...
@@ -22,11 +22,13 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.tpu_int8
import
Int8TpuConfig
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
"aqlm"
:
AQLMConfig
,
"awq"
:
AWQConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
"tpu_int8"
:
Int8TpuConfig
,
"fp8"
:
Fp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
# The order of gptq methods is important for config.py iteration over
...
...
vllm/model_executor/layers/quantization/tpu_int8.py
0 → 100644
View file @
0fa14907
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
ACTIVATION_SCHEMES
=
[
"none"
]
class
Int8TpuConfig
(
QuantizationConfig
):
"""Int8 Quantization Config class for TPU Backend."""
def
__init__
(
self
,
activation_scheme
:
str
=
"none"
,
)
->
None
:
if
activation_scheme
not
in
ACTIVATION_SCHEMES
:
raise
ValueError
(
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
self
.
activation_scheme
=
activation_scheme
def
get_name
(
self
)
->
str
:
return
"tpu_int8"
def
get_supported_act_dtypes
(
self
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
raise
NotImplementedError
(
"This function should not be called with TPU Backend"
)
@
staticmethod
def
get_config_filenames
()
->
List
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"Int8TpuConfig"
:
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
return
cls
(
activation_scheme
=
activation_scheme
)
def
get_quant_method
(
self
,
layer
:
Module
,
prefix
:
str
)
->
Optional
[
"TPUInt8LinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
TPUInt8LinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
TPUInt8LinearMethod
(
LinearMethodBase
):
"""Int8 Linear method for TPU Quant. """
def
__init__
(
self
,
quant_config
:
Int8TpuConfig
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
**
extra_weight_attrs
,
"input_dim"
:
1
,
"output_dim"
:
0
,
})
def
_quantize_weight
(
self
,
weight
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
weight_dtype
=
weight
.
dtype
weight
=
weight
.
cpu
().
to
(
torch
.
float32
)
n_bit
=
8
eps
=
1e-5
max_int
=
2
**
(
n_bit
-
1
)
-
1
min_int
=
-
(
2
**
(
n_bit
-
1
))
max_val
=
weight
.
abs
().
amax
(
dim
=-
1
,
keepdim
=
True
)
max_val
=
max_val
.
clamp
(
min
=
eps
)
qscale
=
max_val
/
max_int
qweight
=
torch
.
clamp
(
torch
.
round
(
weight
*
(
1.0
/
qscale
)),
min_int
,
max_int
).
to
(
torch
.
int8
)
qscale
=
qscale
.
squeeze
().
to
(
weight_dtype
)
return
qweight
,
qscale
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
device
=
layer
.
weight
.
device
qweight
,
qscale
=
self
.
_quantize_weight
(
layer
.
weight
)
qweight
=
qweight
.
to
(
device
)
qscale
=
qscale
.
to
(
device
)
layer
.
weight
=
Parameter
(
qweight
,
requires_grad
=
False
)
layer
.
scale
=
Parameter
(
qscale
,
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
try
:
import
torch_xla.experimental.xla_quantized_matmul
# noqa: F401
except
ImportError
as
err
:
raise
ImportError
(
"Please install torch_xla by following the instructions at "
"https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html "
# noqa: E501
"to run vLLM on TPU."
)
from
err
weight
=
layer
.
weight
scale
=
layer
.
scale
out
=
torch
.
ops
.
xla
.
quantized_matmul
(
x
,
weight
,
scale
)
if
bias
is
not
None
:
out
=
out
+
bias
return
out
vllm/model_executor/model_loader/loader.py
View file @
0fa14907
...
...
@@ -94,12 +94,13 @@ def _get_quantization_config(
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
if
not
is_tpu
():
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
quant_config
.
get_min_capability
():
raise
ValueError
(
f
"The quantization method
{
model_config
.
quantization
}
is not
"
"
supported for the current GPU. "
f
"The quantization method
{
model_config
.
quantization
}
"
"is not
supported for the current GPU. "
f
"Minimum capability:
{
quant_config
.
get_min_capability
()
}
. "
f
"Current capability:
{
capability
}
."
)
supported_dtypes
=
quant_config
.
get_supported_act_dtypes
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment