Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5c342570
Unverified
Commit
5c342570
authored
May 16, 2024
by
alexm-nm
Committed by
GitHub
May 16, 2024
Browse files
Add marlin unit tests and marlin benchmark script (#4815)
parent
973617ae
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
736 additions
and
0 deletions
+736
-0
benchmarks/kernels/benchmark_marlin.py
benchmarks/kernels/benchmark_marlin.py
+183
-0
benchmarks/kernels/benchmark_shapes.py
benchmarks/kernels/benchmark_shapes.py
+75
-0
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+158
-0
vllm/model_executor/layers/quantization/utils/__init__.py
vllm/model_executor/layers/quantization/utils/__init__.py
+0
-0
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+174
-0
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+146
-0
No files found.
benchmarks/kernels/benchmark_marlin.py
0 → 100644
View file @
5c342570
import
argparse
import
torch
import
torch.utils.benchmark
as
benchmark
from
benchmark_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
MarlinWorkspace
,
marlin_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
quantize_weights
,
sort_weights
)
# Model/shape key that is benchmarked when --models is not given.
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]

# Values of size_m swept when --batch-sizes is not given.
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]

# Both settings of each boolean option are benchmarked.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
              size_m, size_k, size_n):
    """Time one (model, quantization config, MKN shape) combination.

    Three measurements are appended to ``results``, in this order: a plain
    ``torch.matmul`` against the dequantized reference weights,
    ``gptq_marlin_gemm`` and ``gptq_marlin_repack``.

    Args:
        results: list receiving the ``torch.utils.benchmark`` measurements.
        model: model/shape key; only used in the printed/recorded sub-label.
        act_order: whether the weights are quantized with act_order.
        is_k_full: passed through to ``gptq_marlin_gemm``.
        num_bits: quantization bit width.
        group_size: quantization group size.
        size_m: number of rows of the activation matrix.
        size_k: shared inner dimension of activation and weight.
        size_n: number of columns of the weight matrix.
    """
    label = "Quant Matmul"

    sub_label = ("{}, act={} k_full={}, b={}, g={}, "
                 "MKN=({}x{}x{})".format(model, act_order, is_k_full,
                                         num_bits, group_size, size_m,
                                         size_k, size_n))

    print(f"Testing: {sub_label}")

    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()

    a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda()

    # Marlin quant
    (
        marlin_w_ref,
        marlin_q_w,
        marlin_s,
        marlin_g_idx,
        marlin_sort_indices,
        marlin_rand_perm,
    ) = marlin_quantize(b, num_bits, group_size, act_order)

    # GPTQ quant: only the quantized weights and the group index are used
    # below, the remaining outputs are discarded.
    _, q_w, _, g_idx, _ = quantize_weights(b, num_bits, group_size, act_order)
    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx"
    # so that group ids are increasing
    repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
    if act_order:
        _, _, repack_sort_indices = sort_weights(q_w, g_idx)

    # Prepare
    marlin_workspace = MarlinWorkspace(size_n)

    # Namespace in which the timed statements are evaluated. Deliberately not
    # called ``globals`` so the builtin of that name is not shadowed.
    timer_globals = {
        "marlin_w_ref": marlin_w_ref,
        "marlin_q_w": marlin_q_w,
        "marlin_s": marlin_s,
        "marlin_g_idx": marlin_g_idx,
        "marlin_sort_indices": marlin_sort_indices,
        "marlin_rand_perm": marlin_rand_perm,
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        "num_bits": num_bits,
        "group_size": group_size,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "is_k_full": is_k_full,
        "a": a,
        "a_tmp": a_tmp,
        "gptq_marlin_gemm": ops.gptq_marlin_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
        "marlin_workspace": marlin_workspace,
    }

    min_run_time = 1

    # Warmup pytorch
    for _ in range(5):
        torch.matmul(a, marlin_w_ref)

    # (description, statement) pairs, timed in this order.
    timed_statements = (
        ("pytorch_gemm", "torch.matmul(a, marlin_w_ref)"),
        (
            "gptq_marlin_gemm",
            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)",  # noqa: E501
        ),
        (
            "gptq_marlin_repack",
            "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)",  # noqa: E501
        ),
    )
    for description, stmt in timed_statements:
        results.append(
            benchmark.Timer(
                stmt=stmt,
                globals=timer_globals,
                label=label,
                sub_label=sub_label,
                description=description,
            ).blocked_autorange(min_run_time=min_run_time))
def main(args):
    """Run the benchmark sweep described by the parsed CLI ``args``.

    Every layer shape of every requested model is combined with each
    act_order / is_k_full / num_bits / group_size / batch-size setting that
    survives the ``--limit-*`` filters; a comparison table of all collected
    measurements is printed at the end.
    """

    def excluded(value, allowed):
        # An empty filter list means "no restriction".
        return len(allowed) > 0 and value not in allowed

    print("Benchmarking models:")
    for idx, name in enumerate(args.models):
        print(f"[{idx}]  {name}")

    results = []

    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
            size_k, size_n = layer[0], layer[1]

            if excluded(size_k, args.limit_k):
                continue

            if excluded(size_n, args.limit_n):
                continue

            for act_order in ACT_ORDER_OPTS:
                for is_k_full in K_FULL_OPTS:
                    for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
                        for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
                            if excluded(group_size, args.limit_group_size):
                                continue

                            # For act_order, the group_size must be less than
                            # size_k
                            if act_order and group_size in (size_k, -1):
                                continue

                            for size_m in args.batch_sizes:
                                bench_run(results, model, act_order,
                                          is_k_full, num_bits, group_size,
                                          size_m, size_k, size_n)

    compare = benchmark.Compare(results)
    compare.print()
# For quick benchmarking use:
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501
#
if __name__ == "__main__":
    # Command-line entry point: build the parser, then hand over to main().
    parser = argparse.ArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    # The three filters share one shape: a list of ints, empty = no filter.
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
benchmarks/kernels/benchmark_shapes.py
0 → 100644
View file @
5c342570
# Weight shapes keyed by "<model>/TP<n>" (plus a synthetic "ideal" entry).
# benchmark_marlin.py reads element 0 of each pair as size_k and element 1
# as size_n.
WEIGHT_SHAPES = {
    "ideal": [[4 * 256 * 32, 256 * 32]],
    "mistralai/Mistral-7B-v0.1/TP1": [
        [4096, 6144],
        [4096, 4096],
        [4096, 28672],
        [14336, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP2": [
        [4096, 3072],
        [2048, 4096],
        [4096, 14336],
        [7168, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP4": [
        [4096, 1536],
        [1024, 4096],
        [4096, 7168],
        [3584, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP1": [
        [4096, 12288],
        [4096, 4096],
        [4096, 22016],
        [11008, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP2": [
        [4096, 6144],
        [2048, 4096],
        [4096, 11008],
        [5504, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP4": [
        [4096, 3072],
        [1024, 4096],
        [4096, 5504],
        [2752, 4096],
    ],
    "meta-llama/Llama-2-13b-hf/TP1": [
        [5120, 15360],
        [5120, 5120],
        [5120, 27648],
        [13824, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP2": [
        [5120, 7680],
        [2560, 5120],
        [5120, 13824],
        [6912, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP4": [
        [5120, 3840],
        [1280, 5120],
        [5120, 6912],
        [3456, 5120],
    ],
    "meta-llama/Llama-2-70b-hf/TP1": [
        [8192, 10240],
        [8192, 8192],
        [8192, 57344],
        [28672, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP2": [
        [8192, 5120],
        [4096, 8192],
        [8192, 28672],
        [14336, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP4": [
        [8192, 2560],
        [2048, 8192],
        [8192, 14336],
        [7168, 8192],
    ],
}
tests/kernels/test_marlin_gemm.py
0 → 100644
View file @
5c342570
"""Tests for the marlin kernel.
Run `pytest tests/kernels/test_marlin_gemm.py`.
"""
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
MarlinWorkspace
,
is_marlin_supported
,
marlin_quantize
,
marlin_weights
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
quantize_weights
,
sort_weights
)
# Both settings of each boolean option are exercised.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]

# Base chunk sizes; the tests multiply them by the K and N factors below.
K_CHUNKS = [128, 256]
N_CHUNKS = [64, 128, 256]

# (m_factor, n_factor, k_factor) triples, unpacked in that order by the tests.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (1, 7 * 4, 5 * 1),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
]
def rand_data(shape):
    """Return a half-precision CUDA tensor of ``shape`` filled by torch.rand."""
    sample = torch.rand(shape)
    return sample.to(torch.half).cuda()
@pytest.mark.skipif(not is_marlin_supported(),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", K_CHUNKS)
@pytest.mark.parametrize("n_chunk", N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
                       mnk_factors):
    """Compare the GPU repack kernel with the Python-side Marlin packing.

    The same quantized weights are packed once with ``marlin_weights`` and
    once by running ``ops.gptq_marlin_repack`` on their GPTQ-packed form;
    both results must match.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")

    # Filter act_order: it is only run with a group size below size_k.
    if act_order and group_size in (-1, size_k):
        return

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided)
    _w_ref, q_w, _scales, g_idx, _rand_perm = quantize_weights(
        b_weight, num_bits, group_size, act_order)

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
    else:
        sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)

    # Pack to Marlin format
    expected = marlin_weights(q_w, size_k, size_n, num_bits)

    # Run Marlin repack GPU kernel
    actual = ops.gptq_marlin_repack(
        q_w_gptq,
        sort_indices,
        size_k,
        size_n,
        num_bits,
    )
    torch.cuda.synchronize()

    assert torch.allclose(expected, actual)
@pytest.mark.skipif(not is_marlin_supported(),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", K_CHUNKS)
@pytest.mark.parametrize("n_chunk", N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
def test_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
):
    """Compare ``ops.gptq_marlin_gemm`` with a matmul on dequantized weights.

    The kernel output must be close (rtol=1e-2) to ``torch.matmul`` of the
    same activations with the reference weights from ``marlin_quantize``.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    # act_order is only run with a group size below size_k.
    if act_order and group_size in (-1, size_k):
        return

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    quantized = marlin_quantize(b_weight, num_bits, group_size, act_order)
    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = quantized

    workspace = MarlinWorkspace(size_n)

    # GEMM dimensions are read back from the tensors actually passed in.
    gemm_m = a_input.shape[0]
    gemm_n = b_weight.shape[1]
    gemm_k = a_input.shape[1]

    output = ops.gptq_marlin_gemm(
        a_input,
        marlin_q_w,
        marlin_s,
        g_idx,
        sort_indices,
        workspace.scratch,
        num_bits,
        gemm_m,
        gemm_n,
        gemm_k,
        is_k_full,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    assert torch.allclose(output, output_ref, rtol=1e-2)
vllm/model_executor/layers/quantization/utils/__init__.py
0 → 100644
View file @
5c342570
vllm/model_executor/layers/quantization/utils/marlin_utils.py
0 → 100644
View file @
5c342570
"""This file is used for /tests and /benchmarks"""
import
numpy
import
torch
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_TILE
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
sort_weights
)
# Compute capability of the current CUDA device, queried once at import time.
# Falls back to (0, 0) when CUDA is unavailable so that merely importing this
# module (e.g. while pytest evaluates a skipif marker) does not raise on a
# CPU-only machine.
__cuda_arch = (torch.cuda.get_device_capability()
               if torch.cuda.is_available() else (0, 0))


def is_marlin_supported():
    """Return True if the CUDA device's major compute capability is >= 8.

    Returns False when no CUDA device is available.
    """
    return __cuda_arch[0] >= 8
# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def _get_perms(num_bits):
    """Build the index permutations used to shuffle weights and scales.

    Args:
        num_bits: quantization bit width; must be 4 or 8.

    Returns:
        A tuple ``(perm, scale_perm, scale_perm_single)`` where ``perm`` is a
        1-D torch tensor holding a permutation of ``range(1024)`` and the two
        scale permutations are Python lists of 64 and 32 indices.

    Raises:
        Exception: if ``num_bits`` is neither 4 nor 8.
    """
    flat_perm = []
    for idx in range(32):
        col = idx // 4
        base = idx % 4
        rows = (
            2 * base,
            2 * base + 1,
            2 * (base + 4),
            2 * (base + 4) + 1,
        )
        # Eight positions inside one 16x16 sub-tile (row-major, 16 per row).
        sub_tile = [
            16 * row + col + 8 * block for block in (0, 1) for row in rows
        ]
        # Repeat the pattern for the four consecutive 256-element sub-tiles.
        for rep in range(4):
            flat_perm.extend(pos + 256 * rep for pos in sub_tile)

    perm = numpy.array(flat_perm)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    # Reorder the entries inside each group of len(interleave) indices.
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)

    scale_perm = [i + 8 * j for i in range(8) for j in range(8)]
    scale_perm_single = [
        2 * i + j for i in range(4) for j in (0, 1, 8, 9, 16, 17, 24, 25)
    ]
    return perm, scale_perm, scale_perm_single
# Permutation tables, computed once at import time and keyed by num_bits.
_perm = {}
_scale_perm = {}
_scale_perm_single = {}
for num_bits in (4, 8):
    perm, scale_perm, scale_perm_single = _get_perms(num_bits)
    _perm[num_bits] = perm
    _scale_perm[num_bits] = scale_perm
    _scale_perm_single[num_bits] = scale_perm_single
def marlin_permute_weights(q_w, size_k, size_n, num_bits,
                           tile=GPTQ_MARLIN_TILE):
    """Reorder a ``(size_k, size_n)`` weight tensor into Marlin tile order.

    Args:
        q_w: tensor of shape ``(size_k, size_n)``.
        size_k: number of rows of ``q_w``; must be a multiple of ``tile``.
        size_n: number of columns of ``q_w``; must be a multiple of ``tile``.
        num_bits: selects the precomputed permutation (4 or 8).
        tile: tile edge length used for the first regrouping step.

    Returns:
        A tensor of shape ``(size_k // tile, size_n * tile)`` holding the
        permuted values.
    """
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    # The message used to label this value "size_k"; it reports size_n.
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    # Apply the precomputed permutation to every chunk of perm.numel() values.
    perm = _perm[num_bits]
    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
def marlin_weights(q_w, size_k, size_n, num_bits):
    """Permute quantized weights to Marlin order and pack them into int32.

    After the permutation, every ``pack_factor`` values along the second axis
    are packed into one 32-bit word, value ``i`` of a group landing at bit
    offset ``num_bits * i``. The result is returned on the device of the
    permuted tensor.
    """
    # Permute
    permuted = marlin_permute_weights(q_w, size_k, size_n, num_bits)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    target_device = permuted.device

    values = permuted.cpu().numpy().astype(numpy.uint32)
    n_rows, n_cols = values.shape[0], values.shape[1]

    packed = numpy.zeros((n_rows, n_cols // pack_factor), dtype=numpy.uint32)
    for slot in range(pack_factor):
        packed |= values[:, slot::pack_factor] << num_bits * slot

    return torch.from_numpy(packed.astype(numpy.int32)).to(target_device)
def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
    """Reorder the columns of the scale tensor for the Marlin layout.

    The grouped permutation is used when ``group_size`` is a real group size
    below ``size_k``; otherwise (``-1`` or a group spanning all of K) the
    "single" permutation is used. The result is reshaped to ``(-1, size_n)``
    and made contiguous.
    """
    is_grouped = group_size != -1 and group_size < size_k
    table = _scale_perm if is_grouped else _scale_perm_single
    column_order = table[num_bits]

    permuted = s.reshape((-1, len(column_order)))[:, column_order]
    return permuted.reshape((-1, size_n)).contiguous()
def marlin_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
    act_order: bool,
):
    """Quantize ``w`` and convert the result to the Marlin format.

    Args:
        w: 2-D weight tensor of shape ``(size_k, size_n)``.
        num_bits: quantization bit width.
        group_size: quantization group size; ``-1`` means one group over all
            of ``size_k``.
        act_order: whether to apply act_order (row permutation) and sort the
            rows by group index afterwards.

    Returns:
        A list ``[w_ref, marlin_q_w, marlin_s, g_idx, sort_indices,
        rand_perm]`` with every tensor moved to ``w.device``.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits,
                                                       group_size, act_order)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
    else:
        sort_indices = torch.empty(0, dtype=torch.int, device=w.device)

    # Reformat to marlin
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, num_bits)

    # Create result: every tensor is moved onto the device of the input.
    outputs = (w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm)
    return [tensor.to(w.device) for tensor in outputs]
class MarlinWorkspace:
    """Holds the zero-initialised integer CUDA scratch tensor ``scratch``."""

    def __init__(self, out_features):
        """Allocate the scratch tensor for ``out_features`` output columns.

        ``out_features`` must be a multiple of GPTQ_MARLIN_MIN_THREAD_N; the
        buffer has one entry per such block times GPTQ_MARLIN_MAX_PARALLEL.
        """
        n_blocks, remainder = divmod(out_features, GPTQ_MARLIN_MIN_THREAD_N)
        assert remainder == 0, (
            "out_features = {} is undivisible by GPTQ_MARLIN_MIN_THREAD_N = {}"
            .format(out_features, GPTQ_MARLIN_MIN_THREAD_N))

        max_workspace_size = n_blocks * GPTQ_MARLIN_MAX_PARALLEL

        self.scratch = torch.zeros(max_workspace_size,
                                   dtype=torch.int,
                                   device="cuda")
vllm/model_executor/layers/quantization/utils/quant_utils.py
0 → 100644
View file @
5c342570
"""This file is used for /tests and /benchmarks"""
import
numpy
import
torch
# Bit widths accepted by the helpers in this module.
SUPPORTED_NUM_BITS = [4, 8]

# Accepted group sizes; -1 is replaced by size_k inside quantize_weights.
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
def get_pack_factor(num_bits):
    """Return how many ``num_bits``-wide values fit into one 32-bit word."""
    supported = num_bits in SUPPORTED_NUM_BITS
    assert supported, f"Unsupported num_bits = {num_bits}"
    word_bits = 32
    return word_bits // num_bits
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
    """Apply one random row permutation to ``q_w`` and ``w_ref``.

    This simulates act_order: row ``i`` initially belongs to group
    ``i // group_size`` and the group index is permuted together with the
    rows.

    Args:
        q_w: 2-D tensor of quantized weights, rows along K.
        w_ref: reference weights with the same shape as ``q_w``.
        group_size: number of consecutive rows per quantization group.

    Returns:
        ``(w_ref, q_w, g_idx, rand_perm)``: the permuted tensors, the int32
        group index of every permuted row and the permutation itself, all
        moved to the original device of ``q_w``.
    """
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    # Group id of every row, computed in one vectorised step instead of a
    # Python loop writing one tensor element at a time.
    g_idx = torch.arange(k_size, dtype=torch.int32) // group_size

    # Simulate act_order by doing a random permutation on K
    rand_perm = torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
                     act_order: bool):
    """Symmetrically quantize ``w`` group-wise along its first dimension.

    Args:
        w: floating-point tensor of shape ``(size_k, size_n)``.
        num_bits: bit width, one of SUPPORTED_NUM_BITS.
        group_size: one of SUPPORTED_GROUP_SIZES or ``size_k``; ``-1`` is
            treated as ``size_k``.
        act_order: if True, additionally permute the rows at random via
            ``permute_rows`` (requires ``group_size < size_k``).

    Returns:
        ``(w_ref, q_w, s, g_idx, rand_perm)`` on the device of ``w``:
        dequantized reference weights, integer quantized weights, per-group
        scales reshaped to ``(-1, size_n)``, and the group index / row
        permutation (both empty unless ``act_order`` is set).
    """
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    max_q_val = 2**num_bits - 1
    half_q_val = (max_q_val + 1) // 2

    is_grouped = group_size < size_k

    # Reshape to [groupsize, -1]
    if is_grouped:
        w = w.reshape((-1, group_size, size_n)).permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group: largest magnitude per column.
    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
    s *= 2 / max_q_val  # 2 => symmetric

    # Quantize: round, shift into the unsigned range, then clamp.
    q_w = torch.round(w / s).int()
    q_w += half_q_val
    q_w = torch.clamp(q_w, 0, max_q_val)

    # Compute ref (dequantized)
    w_ref = (q_w - half_q_val).half() * s

    # Restore original shapes
    if is_grouped:

        def to_original_layout(t):
            # Inverse of the grouping reshape applied to ``w`` above.
            t = t.reshape((group_size, -1, size_n)).permute(1, 0, 2)
            return t.reshape((size_k, size_n)).contiguous()

        q_w = to_original_layout(q_w)
        w_ref = to_original_layout(w_ref)

    s = s.reshape((-1, size_n)).contiguous()

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k)

        w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    """Reorder the rows of ``q_w`` so that ``g_idx`` becomes ascending.

    Returns:
        ``(q_w, g_idx, sort_indices)``: the row-sorted weights, the sorted
        group index and the int32 indices that produce this order, all on the
        original device of ``q_w``.
    """
    target_device = q_w.device

    # Sort based on g_idx
    order = torch.argsort(g_idx).to(dtype=torch.int32)

    sorted_g_idx = g_idx[order].contiguous()
    sorted_q_w = q_w[order, :].contiguous()

    return tuple(
        tensor.to(device=target_device)
        for tensor in (sorted_q_w, sorted_g_idx, order))
def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack quantized weights along K into 32-bit words.

    Every ``pack_factor`` consecutive rows of ``q_w`` are combined into one
    row of the result, row ``i`` of a group landing at bit offset
    ``num_bits * i``.

    Args:
        q_w: tensor of shape ``(size_k, size_n)``.
        num_bits: bit width of each value.
        size_k: number of rows; must be a multiple of the pack factor.
        size_n: number of columns.

    Returns:
        An int32 tensor of shape ``(size_k // pack_factor, size_n)`` on the
        device of ``q_w``.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    target_device = q_w.device

    values = q_w.cpu().numpy().astype(numpy.uint32)

    packed = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
    for slot in range(pack_factor):
        packed |= values[slot::pack_factor, :] << num_bits * slot

    return torch.from_numpy(packed.astype(numpy.int32)).to(target_device)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment