Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e312e52b
Unverified
Commit
e312e52b
authored
Oct 17, 2024
by
Lucas Wilkinson
Committed by
GitHub
Oct 17, 2024
Browse files
[Kernel] Add Exllama as a backend for compressed-tensors (#9395)
parent
dbfa8d31
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
173 additions
and
16 deletions
+173
-16
vllm/envs.py
vllm/envs.py
+9
-0
vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
...el_executor/layers/quantization/kernels/MPLinearKernel.py
+4
-0
vllm/model_executor/layers/quantization/kernels/__init__.py
vllm/model_executor/layers/quantization/kernels/__init__.py
+5
-3
vllm/model_executor/layers/quantization/kernels/exllama.py
vllm/model_executor/layers/quantization/kernels/exllama.py
+140
-0
vllm/model_executor/layers/quantization/kernels/machete.py
vllm/model_executor/layers/quantization/kernels/machete.py
+7
-7
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+6
-6
vllm/scalar_type.py
vllm/scalar_type.py
+2
-0
No files found.
vllm/envs.py
View file @
e312e52b
...
...
@@ -66,6 +66,7 @@ if TYPE_CHECKING:
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1
:
bool
=
False
VLLM_TORCH_COMPILE_LEVEL
:
int
=
0
VLLM_DISABLED_KERNELS
:
List
[
str
]
=
[]
def
get_default_cache_root
():
...
...
@@ -430,6 +431,14 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1"
:
lambda
:
os
.
environ
.
get
(
"VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1"
,
"0"
)
==
"1"
,
# List of quantization kernels that should be disabled, used for testing
# and performance comparisons. Currently only affects MPLinearKernel
# selection
# (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel)
"VLLM_DISABLED_KERNELS"
:
lambda
:
[]
if
"VLLM_DISABLED_KERNELS"
not
in
os
.
environ
else
os
.
environ
[
"VLLM_DISABLED_KERNELS"
].
split
(
","
),
}
# end-env-vars-definition
...
...
vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
View file @
e312e52b
...
...
@@ -42,6 +42,10 @@ class MPLinearKernel(ABC):
self
.
config
=
c
self
.
w_q_name
=
w_q_param_name
self
.
w_s_name
=
w_s_param_name
if
c
.
zero_points
:
assert
w_zp_param_name
is
not
None
if
c
.
has_g_idx
:
assert
w_gidx_param_name
is
not
None
self
.
w_zp_name
=
w_zp_param_name
self
.
w_gidx_name
=
w_gidx_param_name
...
...
vllm/model_executor/layers/quantization/kernels/__init__.py
View file @
e312e52b
import
os
from
typing
import
List
,
Optional
,
Type
import
vllm.envs
as
envs
from
vllm.model_executor.layers.quantization.kernels.exllama
import
(
ExllamaLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.machete
import
(
MacheteLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.marlin
import
(
...
...
@@ -13,6 +15,7 @@ from vllm.platforms import current_platform
_POSSIBLE_KERNELS
:
List
[
Type
[
MPLinearKernel
]]
=
[
MacheteLinearKernel
,
MarlinLinearKernel
,
ExllamaLinearKernel
,
]
...
...
@@ -45,8 +48,7 @@ def choose_mp_linear_kernel(
failure_reasons
=
[]
for
kernel
in
_POSSIBLE_KERNELS
:
if
kernel
.
__name__
in
os
.
environ
.
get
(
"VLLM_DISABLED_KERNELS"
,
""
)
\
.
split
(
","
):
if
kernel
.
__name__
in
envs
.
VLLM_DISABLED_KERNELS
:
failure_reasons
.
append
(
f
'
{
kernel
.
__name__
}
disabled by environment variable'
)
continue
...
...
vllm/model_executor/layers/quantization/kernels/exllama.py
0 → 100644
View file @
e312e52b
from
typing
import
Optional
,
Tuple
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_quantized_values_into_int32
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
from
vllm.scalar_type
import
scalar_types
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
class
ExllamaLinearKernel
(
MPLinearKernel
):
SUPPORTED_QUANT_TYPES
=
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
# In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but
# currently untested so not added to the list
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
60
@
classmethod
def
can_implement
(
cls
,
c
:
MPLinearLayerConfig
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
c
.
has_g_idx
and
\
c
.
partition_weight_shape
[
0
]
!=
c
.
full_weight_shape
[
0
]:
return
False
,
"Act reordering currently not supported by Exllama, "
\
"when the input features are partitioned across "
\
"devices"
if
c
.
partition_weight_shape
[
1
]
%
(
32
//
c
.
weight_type
.
size_bits
)
!=
0
:
return
False
,
"Output features must be a multiple of the pack "
\
"factor (32 / num_bits) so that we can correctly "
\
"pack the zero points"
if
c
.
act_type
!=
torch
.
float16
:
return
False
,
"Exllama only supports float16 activations"
if
c
.
weight_type
not
in
cls
.
SUPPORTED_QUANT_TYPES
:
return
False
,
f
"Quant type (
{
c
.
weight_type
}
) not supported by "
\
"Exllama, supported types are: "
\
f
"
{
cls
.
SUPPORTED_QUANT_TYPES
}
"
if
c
.
full_weight_shape
[
0
]
%
c
.
group_size
!=
0
:
return
False
,
f
"Group size (
{
c
.
group_size
}
) does not evenly divide"
\
" the number of input features "
\
f
"(
{
c
.
full_weight_shape
[
0
]
}
)"
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
):
c
=
self
.
config
# For Exllama, we need to set a zero-point tensor if there is not one
if
not
c
.
zero_points
:
self
.
w_zp_name
=
"qzeros"
device
=
getattr
(
layer
,
self
.
w_q_name
).
device
groups
=
c
.
partition_weight_shape
[
0
]
//
c
.
group_size
out_features
=
c
.
partition_weight_shape
[
1
]
if
c
.
weight_type
.
has_bias
():
# if the type has a bias we have to create a zeros tensor that
# contains the bias values repeated for each group (-1 due to
# a bug in the original GPTQ checkpoint format leading to
# exllama kernel adding 1 to the zero points during inference)
# Documentation of the bug can be found here:
# https://garden.danieldk.eu/GPTQ-Checkpoint-Format
zeros
=
torch
.
full
((
groups
,
out_features
),
c
.
weight_type
.
bias
-
1
,
dtype
=
torch
.
int32
,
device
=
device
)
else
:
raise
NotImplementedError
(
"A 0 zero-point is not supported by Exllama due to "
"a bug in the original GPTQ checkpoint format leading to "
"exllama kernel adding 1 to the zero points during "
"inference"
)
zeros
=
pack_quantized_values_into_int32
(
zeros
,
c
.
weight_type
,
packed_dim
=
1
)
setattr
(
layer
,
self
.
w_zp_name
,
torch
.
nn
.
Parameter
(
zeros
,
requires_grad
=
False
))
if
c
.
has_g_idx
:
def
transform_w_g_idx
(
x
):
# Exllama wants the permutation array instead of the group
# indices
return
torch
.
argsort
(
x
).
to
(
torch
.
int
)
self
.
_transform_param
(
layer
,
self
.
w_gidx_name
,
transform_w_g_idx
)
else
:
self
.
w_gidx_name
=
"g_idx"
empty_g_idx
=
torch
.
nn
.
Parameter
(
torch
.
empty
((
0
,
),
dtype
=
torch
.
int
,
device
=
device
),
requires_grad
=
False
)
setattr
(
layer
,
self
.
w_gidx_name
,
empty_g_idx
)
def
transform_w_q
(
x
):
assert
isinstance
(
x
,
BasevLLMParameter
)
assert
self
.
w_gidx_name
is
not
None
g_idx
=
getattr
(
layer
,
self
.
w_gidx_name
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
)
x_cont
=
x
.
data
.
contiguous
()
ops
.
gptq_shuffle
(
x_cont
,
g_idx
,
c
.
weight_type
.
size_bits
)
return
x_cont
def
transform_w_s
(
x
):
assert
isinstance
(
x
,
BasevLLMParameter
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
)
x
.
data
=
x
.
data
.
contiguous
()
return
x
.
to
(
dtype
=
c
.
act_type
)
# Repack weights and scales for Machete
self
.
_transform_param
(
layer
,
self
.
w_q_name
,
transform_w_q
)
self
.
_transform_param
(
layer
,
self
.
w_s_name
,
transform_w_s
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
c
=
self
.
config
x_2d
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out_shape
=
x
.
shape
[:
-
1
]
+
(
c
.
partition_weight_shape
[
1
],
)
w_q
,
w_s
,
w_zp
,
w_g_idx
=
self
.
_get_weight_params
(
layer
)
assert
w_zp
is
not
None
,
"Zero points are required by Exllama"
assert
w_g_idx
is
not
None
,
"Group index is required by Exllama"
output
=
ops
.
gptq_gemm
(
x_2d
,
w_q
,
w_zp
,
w_s
,
w_g_idx
,
True
,
c
.
weight_type
.
size_bits
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
return
output
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/kernels/machete.py
View file @
e312e52b
...
...
@@ -8,7 +8,7 @@ from vllm.model_executor.layers.quantization.utils.machete_utils import (
MACHETE_SUPPORTED_GROUP_SIZES
,
check_machete_supports_shape
,
query_machete_supported_quant_types
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_
weight
s_into_int32
,
unpack_
weight
s_into_int32
)
pack_
quantized_value
s_into_int32
,
unpack_
quantized_value
s_into_int32
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
...
...
@@ -71,13 +71,13 @@ class MacheteLinearKernel(MPLinearKernel):
assert
isinstance
(
x
,
BasevLLMParameter
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
)
if
c
.
has_g_idx
:
x_unpacked
=
unpack_
weight
s_into_int32
(
x
.
data
,
c
.
weight_type
,
packed_dim
=
0
)
x_unpacked
=
unpack_
quantized_value
s_into_int32
(
x
.
data
,
c
.
weight_type
,
packed_dim
=
0
)
x_perm
=
x_unpacked
[
perm
,
:]
x
.
data
=
pack_
weight
s_into_int32
(
x_perm
,
c
.
weight_type
,
packed_dim
=
0
)
x
.
data
=
pack_
quantized_value
s_into_int32
(
x_perm
,
c
.
weight_type
,
packed_dim
=
0
)
x
.
data
=
ops
.
machete_prepack_B
(
x
.
data
.
t
().
contiguous
().
t
(),
self
.
config
.
weight_type
)
return
x
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
e312e52b
...
...
@@ -20,9 +20,9 @@ FUSED_LAYER_NAME_MAPPING = {
}
def
pack_
weight
s_into_int32
(
w_q
:
torch
.
Tensor
,
wtype
:
ScalarType
,
packed_dim
:
int
=
0
):
def
pack_
quantized_value
s_into_int32
(
w_q
:
torch
.
Tensor
,
wtype
:
ScalarType
,
packed_dim
:
int
=
0
):
# move dim to pack to the end
perm
=
(
*
[
i
for
i
in
range
(
len
(
w_q
.
shape
))
if
i
!=
packed_dim
],
packed_dim
)
inv_perm
=
tuple
(
perm
.
index
(
i
)
for
i
in
range
(
len
(
perm
)))
...
...
@@ -42,9 +42,9 @@ def pack_weights_into_int32(w_q: torch.Tensor,
return
res
.
permute
(
inv_perm
)
def
unpack_
weight
s_into_int32
(
w_q
:
torch
.
Tensor
,
wtype
:
ScalarType
,
packed_dim
:
int
=
0
):
def
unpack_
quantized_value
s_into_int32
(
w_q
:
torch
.
Tensor
,
wtype
:
ScalarType
,
packed_dim
:
int
=
0
):
# move dim to pack to the end
perm
=
(
*
[
i
for
i
in
range
(
len
(
w_q
.
shape
))
if
i
!=
packed_dim
],
packed_dim
)
inv_perm
=
tuple
(
perm
.
index
(
i
)
for
i
in
range
(
len
(
perm
)))
...
...
vllm/scalar_type.py
View file @
e312e52b
...
...
@@ -27,6 +27,8 @@ class scalar_types:
float6_e3m2f
=
ScalarType
.
float_
(
3
,
2
,
True
,
NanRepr
.
NONE
.
value
)
# "gptq" types
uint2b2
=
ScalarType
.
uint
(
2
,
2
)
uint3b4
=
ScalarType
.
uint
(
3
,
4
)
uint4b8
=
ScalarType
.
uint
(
4
,
8
)
uint8b128
=
ScalarType
.
uint
(
8
,
128
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment