Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
27ce8547
"vscode:/vscode.git/clone" did not exist on "f851b8426659d00a264eaab5ef0f69f70e592e37"
Unverified
Commit
27ce8547
authored
May 19, 2024
by
Alexander Matveev
Committed by
GitHub
May 19, 2024
Browse files
[Kernel] Add marlin_24 unit tests (#4901)
parent
f68470e8
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
649 additions
and
103 deletions
+649
-103
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+74
-13
vllm/model_executor/layers/quantization/gptq_marlin_24.py
vllm/model_executor/layers/quantization/gptq_marlin_24.py
+19
-8
vllm/model_executor/layers/quantization/utils/format_24.py
vllm/model_executor/layers/quantization/utils/format_24.py
+308
-0
vllm/model_executor/layers/quantization/utils/marlin_24_perms.py
...del_executor/layers/quantization/utils/marlin_24_perms.py
+58
-0
vllm/model_executor/layers/quantization/utils/marlin_perms.py
.../model_executor/layers/quantization/utils/marlin_perms.py
+58
-0
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+132
-82
No files found.
tests/kernels/test_marlin_gemm.py
View file @
27ce8547
...
@@ -7,23 +7,32 @@ import torch
...
@@ -7,23 +7,32 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_perms
import
(
marlin_perm
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
MarlinWorkspace
,
is_marlin_supported
,
marlin_quantize
,
marlin_weights
)
MarlinWorkspace
,
compute_max_diff
,
is_marlin_supported
,
marlin_24_quantize
,
marlin_quantize
,
marlin_weights
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
quantize_weights
,
sort_weights
)
gptq_pack
,
quantize_weights
,
sort_weights
)
ACT_ORDER_OPTS
=
[
False
,
True
]
ACT_ORDER_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
K_CHUNKS
=
[
128
,
256
]
MARLIN_K_CHUNKS
=
[
128
]
N_CHUNKS
=
[
64
,
128
,
256
]
MARLIN_N_CHUNKS
=
[
64
,
128
,
256
]
MARLIN_24_K_CHUNKS
=
[
128
]
MARLIN_24_N_CHUNKS
=
[
256
]
MNK_FACTORS
=
[
MNK_FACTORS
=
[
(
1
,
1
,
1
),
(
1
,
1
,
1
),
(
1
,
4
,
8
),
(
1
,
4
,
8
),
(
1
,
7
,
5
),
(
1
,
7
,
5
),
(
1
,
7
*
4
,
5
*
1
),
(
13
,
17
,
67
),
(
13
,
17
,
67
),
(
26
,
37
,
13
),
(
26
,
37
,
13
),
(
67
,
13
,
11
),
(
67
,
13
,
11
),
...
@@ -31,14 +40,13 @@ MNK_FACTORS = [
...
@@ -31,14 +40,13 @@ MNK_FACTORS = [
def
rand_data
(
shape
):
def
rand_data
(
shape
):
data
=
torch
.
rand
(
shape
).
to
(
torch
.
half
).
cuda
()
return
torch
.
randn
(
shape
,
dtype
=
torch
.
half
,
device
=
"cuda"
)
return
data
@
pytest
.
mark
.
skipif
(
not
is_marlin_supported
(),
@
pytest
.
mark
.
skipif
(
not
is_marlin_supported
(),
reason
=
"Marlin is not supported on this GPU type."
)
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_
K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_
N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
...
@@ -82,7 +90,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
...
@@ -82,7 +90,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
# Pack to Marlin format
# Pack to Marlin format
marlin_q_w_1
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
)
marlin_q_w_1
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
marlin_perm
[
num_bits
])
# Run Marlin repack GPU kernel
# Run Marlin repack GPU kernel
marlin_q_w_2
=
ops
.
gptq_marlin_repack
(
marlin_q_w_2
=
ops
.
gptq_marlin_repack
(
...
@@ -99,8 +108,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
...
@@ -99,8 +108,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
@
pytest
.
mark
.
skipif
(
not
is_marlin_supported
(),
@
pytest
.
mark
.
skipif
(
not
is_marlin_supported
(),
reason
=
"Marlin is not supported on this GPU type."
)
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_
K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_
N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
...
@@ -136,7 +145,8 @@ def test_marlin_gemm(
...
@@ -136,7 +145,8 @@ def test_marlin_gemm(
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
b_weight
,
num_bits
,
group_size
,
act_order
)
b_weight
,
num_bits
,
group_size
,
act_order
)
workspace
=
MarlinWorkspace
(
size_n
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
output
=
ops
.
gptq_marlin_gemm
(
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
a_input
,
...
@@ -155,4 +165,55 @@ def test_marlin_gemm(
...
@@ -155,4 +165,55 @@ def test_marlin_gemm(
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
assert
torch
.
allclose
(
output
,
output_ref
,
rtol
=
1e-2
)
max_diff
=
compute_max_diff
(
output
,
output_ref
)
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_marlin_supported
(),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_24_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_24_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
def
test_marlin_24_gemm
(
k_chunk
,
n_chunk
,
num_bits
,
group_size
,
mnk_factors
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
print
(
f
"MNK =
{
size_m
}
{
size_n
}
{
size_k
}
"
)
print
(
f
"groupsize =
{
group_size
}
"
)
a_input
=
rand_data
((
size_m
,
size_k
))
b_weight
=
rand_data
((
size_k
,
size_n
))
(
w_24_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
)
=
marlin_24_quantize
(
b_weight
,
num_bits
,
group_size
)
workspace_24
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_MAX_PARALLEL
)
output_ref
=
torch
.
matmul
(
a_input
,
w_24_ref
)
output
=
ops
.
gptq_marlin_24_gemm
(
a_input
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
,
workspace_24
.
scratch
,
num_bits
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
)
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
vllm/model_executor/layers/quantization/gptq_marlin_24.py
View file @
27ce8547
...
@@ -12,6 +12,15 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -12,6 +12,15 @@ from vllm.model_executor.utils import set_weight_attrs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
GPTQ_MARLIN_24_TILE
=
16
GPTQ_MARLIN_24_MIN_THREAD_N
=
128
GPTQ_MARLIN_24_MIN_THREAD_K
=
128
GPTQ_MARLIN_24_MAX_PARALLEL
=
16
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
=
[
-
1
,
128
]
GPTQ_MARLIN_24_SUPPORTED_SYM
=
[
True
]
class
GPTQMarlin24Config
(
QuantizationConfig
):
class
GPTQMarlin24Config
(
QuantizationConfig
):
"""Config class for Marlin24.
"""Config class for Marlin24.
...
@@ -25,15 +34,17 @@ class GPTQMarlin24Config(QuantizationConfig):
...
@@ -25,15 +34,17 @@ class GPTQMarlin24Config(QuantizationConfig):
self
.
weight_bits
=
weight_bits
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
group_size
=
group_size
if
self
.
weight_bits
!=
4
and
self
.
weight_bits
!=
8
:
# Verify
raise
ValueError
(
"weight_bits must be 4 or 8. Got = {}"
.
format
(
if
self
.
weight_bits
not
in
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
:
self
.
weight_bits
))
raise
ValueError
(
f
"Marlin_24 does not support weight_bits =
{
self
.
weight_bits
}
. "
if
self
.
group_size
!=
128
and
self
.
group_size
!=
-
1
:
f
"Only weight_bits =
{
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
}
"
"are supported."
)
if
self
.
group_size
not
in
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
:
raise
ValueError
(
raise
ValueError
(
"Currently, only
group
size
128 and -1 (channelwise)
"
f
"Marlin_24 does not support
group
_
size
=
{
self
.
group_size
}
.
"
"is supported for Marlin24, but got group_size of
"
f
"Only group_sizes =
{
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
}
"
f
"
{
self
.
group_size
}
"
)
"are supported.
"
)
# 4 Bits packed into 32 bit datatype.
# 4 Bits packed into 32 bit datatype.
self
.
pack_factor
=
32
//
self
.
weight_bits
self
.
pack_factor
=
32
//
self
.
weight_bits
...
...
vllm/model_executor/layers/quantization/utils/format_24.py
0 → 100644
View file @
27ce8547
#
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
#
import
torch
# This is PyTorch implementation of main part of reorder_meta()
# function, from tools/util/include/cutlass/util/host_reorder.h file
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
# GEMM decides upon layout of this matrix, and at the moment for the
# sparse GEMM executed on tensor cores, this is layout described by
# ColumnMajorInterleaved<2> data structure, in
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
# reordering of meta matrix into meta_reordered matrix calculated
# according to these segments of CUTLASS code is re-implemented here.
# Note that this calculation produces offsets for scattering metadata
# matrix elements into reordered metadata matrix elements (or,
# equivalently, for gathering reordered metadata matrix element back
# into metadata matrix elements).
def
_calculate_meta_reordering_scatter_offsets
(
m
,
meta_ncols
,
meta_dtype
,
device
):
dst_rows
=
torch
.
arange
(
0
,
m
,
device
=
device
)[:,
None
].
repeat
(
1
,
meta_ncols
)
dst_cols
=
torch
.
arange
(
0
,
meta_ncols
,
device
=
device
).
repeat
(
m
,
1
)
# Reorder the rows, then swizzle the 2x2 blocks.
group_x
=
64
group_y
=
32
if
meta_dtype
.
itemsize
==
2
else
16
dst_rows
=
(
dst_rows
//
group_x
*
group_x
+
(
dst_rows
%
2
)
*
2
+
(
dst_rows
%
8
)
//
4
+
((
dst_rows
%
group_y
)
%
4
)
//
2
*
32
+
((
dst_rows
%
group_x
)
//
8
)
*
4
)
topright
=
((
dst_rows
%
2
==
0
)
&
(
dst_cols
%
2
==
1
)).
to
(
torch
.
int8
)
bottomleft
=
((
dst_rows
%
2
==
1
)
&
(
dst_cols
%
2
==
0
)).
to
(
torch
.
int8
)
dst_rows
+=
topright
-
bottomleft
dst_cols
-=
topright
-
bottomleft
# Assumed that meta tensor is to be stored in CUTLASS
# InterleavedColumnMajor layout, and reverse engineered
# corresponding code to store values into this tensor.
interleave
=
2
cols_maj
=
dst_cols
//
interleave
cols_min
=
dst_cols
%
interleave
return
(
cols_maj
*
m
*
interleave
+
dst_rows
*
interleave
+
cols_min
).
view
(
-
1
)
# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
def
sparse_semi_structured_from_dense_cutlass
(
dense
):
if
dense
.
dim
()
!=
2
:
raise
RuntimeError
(
f
"Expected 2-dimensional dense tensor, got
{
dense
.
dim
()
}
-dimensional tensor"
# noqa: E501
)
m
,
k
=
dense
.
shape
device
=
dense
.
device
meta_dtype
=
torch
.
int8
if
dense
.
dtype
==
torch
.
int8
:
meta_dtype
=
torch
.
int32
elif
dense
.
dtype
in
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
,
torch
.
int32
]:
meta_dtype
=
torch
.
int16
else
:
raise
RuntimeError
(
f
"Invalid datatype
{
dense
.
dtype
}
of dense matrix"
)
quadbits_per_meta_elem
=
meta_dtype
.
itemsize
*
8
//
4
if
quadbits_per_meta_elem
not
in
(
4
,
8
):
raise
RuntimeError
(
"Invalid number of elements per meta element calculated"
)
if
meta_dtype
==
torch
.
int32
:
if
m
%
16
!=
0
:
raise
RuntimeError
(
f
"Number of rows of dense matrix
{
m
}
must be divisible by 16"
)
else
:
if
m
%
32
!=
0
:
raise
RuntimeError
(
f
"Number of rows of dense matrix
{
m
}
must be divisible by 32"
)
if
k
%
(
4
*
quadbits_per_meta_elem
)
!=
0
:
raise
RuntimeError
(
f
"Number of columns of dense matrix
{
k
}
must be divisible by
{
4
*
quadbits_per_meta_elem
}
"
# noqa: E501
)
if
dense
.
dtype
!=
torch
.
float
:
ksparse
=
4
dense_4
=
dense
.
view
(
-
1
,
k
//
ksparse
,
ksparse
)
m0
,
m1
,
m2
,
m3
=
(
dense_4
!=
0
).
unbind
(
-
1
)
else
:
ksparse
=
2
dense_2
=
dense
.
view
(
-
1
,
k
//
ksparse
,
ksparse
)
m0
,
m2
=
m1
,
m3
=
(
dense_2
!=
0
).
unbind
(
-
1
)
meta_ncols
=
k
//
(
ksparse
*
quadbits_per_meta_elem
)
# Encoding quadruples of True/False values as follows:
# [True, True, False, False] -> 0b0100
# [True, False, True, False] -> 0b1000
# [False, True, True, False] -> 0b1001
# [True, False, False, True ] -> 0b1100
# [False, True, False, True ] -> 0b1101
# [False, False, True, True ] -> 0b1110
# Thus, lower two bits in the encoding are index of the True value
# at the lowest index in the quadruple, and the higher two bits in
# the encoding are index of the other True value in the quadruple.
# In case there are less than two True values, than False value or
# values at some index or indices are considered True for the
# encoding. In case there are more than two True values, then the
# excess True value(s) at some indices are considered False for
# the encoding. The exact encodings used for these cases are as
# follows:
# [False, False, False, False] -> 0b1110
# [False, False, False, True ] -> 0b1110
# [False, False, True, False] -> 0b1110
# [False, True, False, False] -> 0b1001
# [False, True, True, True ] -> 0b1101
# [True, False, False, False] -> 0b1000
# [True, False, True, True ] -> 0b1100
# [True, True, False, True ] -> 0b0100
# [True, True, True, False] -> 0b0100
# [True, True, True, True ] -> 0b0100
# These particular encodings are chosen, with the help of Espresso
# logic minimizer software, for the purpose of minimization of
# corresponding Boolean functions, that translate non-zero flags
# into encoding bits. Note also possible choices for the first
# and last of these encodings were limited only to (0b0100,
# 0b1110), in order to produce valid encodings for 1:2 sparsity
# case.
expr0
=
m0
&
m1
expr1
=
~
m0
&
m1
expr2
=
~
m0
&
~
m1
bit0
=
expr1
bit1
=
expr2
bit2
=
expr0
|
expr2
|
m3
bit3
=
expr1
|
~
m1
idxs0
=
bit0
|
(
bit1
.
to
(
torch
.
int64
)
<<
1
)
idxs1
=
bit2
|
(
bit3
.
to
(
torch
.
int64
)
<<
1
)
if
dense
.
dtype
!=
torch
.
float
:
sparse0
=
dense_4
.
gather
(
-
1
,
idxs0
.
unsqueeze
(
-
1
))
# type: ignore[possibly-undefined]
sparse1
=
dense_4
.
gather
(
-
1
,
idxs1
.
unsqueeze
(
-
1
))
sparse
=
torch
.
stack
((
sparse0
,
sparse1
),
dim
=-
1
).
view
(
m
,
k
//
2
)
else
:
sparse
=
dense_2
.
gather
(
-
1
,
idxs0
.
unsqueeze
(
-
1
)
//
2
).
view
(
m
,
k
//
2
)
# type: ignore[possibly-undefined]
meta_4
=
idxs0
|
(
idxs1
<<
2
)
meta_n
=
meta_4
.
view
(
(
-
1
,
meta_ncols
,
quadbits_per_meta_elem
)).
to
(
meta_dtype
)
if
quadbits_per_meta_elem
==
4
:
meta
=
(
meta_n
[:,
:,
0
]
|
(
meta_n
[:,
:,
1
]
<<
4
)
|
(
meta_n
[:,
:,
2
]
<<
8
)
|
(
meta_n
[:,
:,
3
]
<<
12
))
elif
quadbits_per_meta_elem
==
8
:
meta
=
(
meta_n
[:,
:,
0
]
|
(
meta_n
[:,
:,
1
]
<<
4
)
|
(
meta_n
[:,
:,
2
]
<<
8
)
|
(
meta_n
[:,
:,
3
]
<<
12
)
|
(
meta_n
[:,
:,
4
]
<<
16
)
|
(
meta_n
[:,
:,
5
]
<<
20
)
|
(
meta_n
[:,
:,
6
]
<<
24
)
|
(
meta_n
[:,
:,
7
]
<<
28
))
# Reorder meta tensor elements.
meta_reordered
=
meta
.
new_empty
(
(
m
*
meta_ncols
,
))
# type: ignore[possibly-undefined]
meta_offsets
=
_calculate_meta_reordering_scatter_offsets
(
m
,
meta_ncols
,
meta_dtype
,
device
)
meta_reordered
.
scatter_
(
0
,
meta_offsets
,
meta
.
view
(
-
1
))
return
(
sparse
,
meta_reordered
.
view
(
m
,
meta_ncols
))
# This function performs reverse of the function above - it
# reconstructs dense matrix from a pair of "compressed" matrix, given
# in the layout used by CUTLASS backend, and accompanying metadata
# matrix.
def
sparse_semi_structured_to_dense_cutlass
(
sparse
,
meta_reordered
):
if
sparse
.
dim
()
!=
2
:
raise
RuntimeError
(
f
"Expected 2-dimensional sparse tensor, got
{
sparse
.
dim
()
}
-dimensional tensor"
# noqa: E501
)
m
,
k
=
sparse
.
shape
device
=
sparse
.
device
if
meta_reordered
.
dim
()
!=
2
:
raise
RuntimeError
(
f
"Expected 2-dimensional meta tensor, got
{
meta_reordered
.
dim
()
}
-dimensional tensor"
# noqa: E501
)
if
meta_reordered
.
device
!=
device
:
raise
RuntimeError
(
f
"Expected meta matrix to be on
{
device
}
device, got matrix on
{
meta_reordered
.
device
}
device"
# noqa: E501
)
meta_dtype
=
meta_reordered
.
dtype
if
meta_dtype
not
in
(
torch
.
int16
,
torch
.
int32
):
raise
RuntimeError
(
f
"Invalid datatype
{
meta_dtype
}
of meta matrix"
)
quadbits_per_meta_elem
=
meta_dtype
.
itemsize
*
8
//
4
ksparse
=
4
if
sparse
.
dtype
!=
torch
.
float
else
2
meta_nrows
,
meta_ncols
=
meta_reordered
.
shape
if
meta_nrows
!=
m
:
raise
RuntimeError
(
f
"Number of rows of meta matrix
{
meta_nrows
}
must be equal to number of columns of spase matrix
{
m
}
"
# noqa: E501
)
if
meta_ncols
*
ksparse
*
quadbits_per_meta_elem
!=
2
*
k
:
raise
RuntimeError
(
f
"Number of columns of sparse matrix
{
k
}
different from the
{
meta_ncols
*
ksparse
*
quadbits_per_meta_elem
//
2
}
, "
# noqa: E501
"expected according to the number of columns of meta matrix"
)
# Undo meta tensor elements reordering.
meta_offsets
=
_calculate_meta_reordering_scatter_offsets
(
m
,
meta_ncols
,
meta_dtype
,
device
)
meta
=
torch
.
gather
(
meta_reordered
.
view
(
-
1
),
0
,
meta_offsets
).
view
(
m
,
meta_ncols
)
# Unpack sparse tensor back to original dense tensor, using
# information provided by meta tensor. Note that torch.float
# datatype is handled pretty much the same as
# torch.half/torch.bfloat16, as metadata for a pair of torch.float
# value is encoded as if underlying 8 bytes contain four
# torch.half/torch.bfloat16 values, where either first two or last
# two are zeros.
meta_2
=
torch
.
empty
(
(
m
,
meta_ncols
,
2
*
quadbits_per_meta_elem
),
dtype
=
meta_dtype
,
device
=
device
,
)
if
quadbits_per_meta_elem
==
4
:
meta_2
[:,
:,
0
]
=
meta
&
0b11
meta_2
[:,
:,
1
]
=
(
meta
>>
2
)
&
0b11
meta_2
[:,
:,
2
]
=
(
meta
>>
4
)
&
0b11
meta_2
[:,
:,
3
]
=
(
meta
>>
6
)
&
0b11
meta_2
[:,
:,
4
]
=
(
meta
>>
8
)
&
0b11
meta_2
[:,
:,
5
]
=
(
meta
>>
10
)
&
0b11
meta_2
[:,
:,
6
]
=
(
meta
>>
12
)
&
0b11
meta_2
[:,
:,
7
]
=
(
meta
>>
14
)
&
0b11
elif
quadbits_per_meta_elem
==
8
:
meta_2
[:,
:,
0
]
=
meta
&
0b11
meta_2
[:,
:,
1
]
=
(
meta
>>
2
)
&
0b11
meta_2
[:,
:,
2
]
=
(
meta
>>
4
)
&
0b11
meta_2
[:,
:,
3
]
=
(
meta
>>
6
)
&
0b11
meta_2
[:,
:,
4
]
=
(
meta
>>
8
)
&
0b11
meta_2
[:,
:,
5
]
=
(
meta
>>
10
)
&
0b11
meta_2
[:,
:,
6
]
=
(
meta
>>
12
)
&
0b11
meta_2
[:,
:,
7
]
=
(
meta
>>
14
)
&
0b11
meta_2
[:,
:,
8
]
=
(
meta
>>
16
)
&
0b11
meta_2
[:,
:,
9
]
=
(
meta
>>
18
)
&
0b11
meta_2
[:,
:,
10
]
=
(
meta
>>
20
)
&
0b11
meta_2
[:,
:,
11
]
=
(
meta
>>
22
)
&
0b11
meta_2
[:,
:,
12
]
=
(
meta
>>
24
)
&
0b11
meta_2
[:,
:,
13
]
=
(
meta
>>
26
)
&
0b11
meta_2
[:,
:,
14
]
=
(
meta
>>
28
)
&
0b11
meta_2
[:,
:,
15
]
=
(
meta
>>
30
)
&
0b11
dense_offsets
=
meta_2
.
view
(
-
1
)
+
(
torch
.
arange
(
0
,
2
*
m
*
k
//
ksparse
,
device
=
device
)
*
4
).
view
(
-
1
,
1
).
repeat
(
1
,
2
).
view
(
-
1
)
dense
=
torch
.
zeros
((
m
*
2
*
k
,
),
dtype
=
sparse
.
dtype
,
device
=
device
)
if
sparse
.
dtype
!=
torch
.
float
:
# dense.scatter_(0, dense_offsets, sparse.view(-1))
dense
.
scatter_
(
0
,
dense_offsets
,
sparse
.
reshape
(
-
1
))
else
:
dense
.
view
(
torch
.
half
).
scatter_
(
0
,
dense_offsets
,
sparse
.
view
(
torch
.
half
).
view
(
-
1
))
return
dense
.
view
(
m
,
2
*
k
)
def
mask_creator
(
tensor
):
"""
Class for creating N:M sparsity masks.
Masks will be created using the N:M ratio, where for every block of
M weights, N will be pruned based on ranked weight value. Each mask
will correspond to the given tensor.
:param N: The number of weights in a group to keep
:param M: The size of a weight group
"""
N
=
2
M
=
4
mask
=
None
# for i, tensor in enumerate(tensors):
if
tensor
.
numel
()
%
M
!=
0
:
raise
ValueError
(
f
"Tensor of size
{
tensor
.
shape
}
can't be evenly divided into "
f
"
{
M
}
groups"
)
num_groups
=
tensor
.
numel
()
//
M
# N:M sparsity for linear layers
tensor_temp
=
tensor
.
detach
().
abs
().
reshape
(
num_groups
,
M
)
index
=
torch
.
argsort
(
tensor_temp
,
dim
=
1
)[:,
:
int
(
M
-
N
)]
w_b
=
torch
.
ones
(
tensor_temp
.
shape
,
device
=
tensor_temp
.
device
)
mask
=
w_b
.
scatter_
(
dim
=
1
,
index
=
index
,
value
=
0
).
reshape
(
tensor
.
shape
)
return
mask
vllm/model_executor/layers/quantization/utils/marlin_24_perms.py
0 → 100644
View file @
27ce8547
"""This file is used for /tests and /benchmarks"""
import
numpy
import
torch
# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def
get_perms_24
(
num_bits
):
perm_list
=
[]
for
i
in
range
(
32
):
perm1
=
[]
col
=
i
//
4
col_o
=
col
//
2
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col_o
*
256
+
8
*
(
col
%
2
)
+
4
*
block
)
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
1
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm_list
)
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
ValueError
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
perm
=
perm
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
scale_perm
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
*
8
+
j
for
j
in
[
0
,
4
,
1
,
5
,
2
,
6
,
3
,
7
]])
scale_perm_single
=
[]
for
i
in
range
(
8
):
scale_perm_single
.
extend
([
8
*
i
+
j
for
j
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]])
return
perm
,
scale_perm
,
scale_perm_single
marlin_24_perm
=
{}
marlin_24_scale_perm
=
{}
marlin_24_scale_perm_single
=
{}
for
num_bits
in
[
4
,
8
]:
perm_24
,
scale_perm_24
,
scale_perm_single_24
=
get_perms_24
(
num_bits
)
marlin_24_perm
[
num_bits
]
=
perm_24
marlin_24_scale_perm
[
num_bits
]
=
scale_perm_24
marlin_24_scale_perm_single
[
num_bits
]
=
scale_perm_single_24
vllm/model_executor/layers/quantization/utils/marlin_perms.py
0 → 100644
View file @
27ce8547
"""This file is used for /tests and /benchmarks"""
import
numpy
import
torch
# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def
get_perms
(
num_bits
):
perm_list
=
[]
for
i
in
range
(
32
):
perm1
=
[]
col
=
i
//
4
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col
+
8
*
block
)
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm_list
)
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
perm
=
perm
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
scale_perm
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
perm
,
scale_perm
,
scale_perm_single
marlin_perm
=
{}
marlin_scale_perm
=
{}
marlin_scale_perm_single
=
{}
for
num_bits
in
[
4
,
8
]:
perm
,
scale_perm
,
scale_perm_single
=
get_perms
(
num_bits
)
marlin_perm
[
num_bits
]
=
perm
marlin_scale_perm
[
num_bits
]
=
scale_perm
marlin_scale_perm_single
[
num_bits
]
=
scale_perm_single
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
27ce8547
"""This file is used for /tests and /benchmarks"""
"""This file is used for /tests and /benchmarks"""
import
random
import
numpy
import
numpy
import
torch
import
torch
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
from
vllm.model_executor.layers.quantization.utils.format_24
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_TILE
)
mask_creator
,
sparse_semi_structured_from_dense_cutlass
)
from
vllm.model_executor.layers.quantization.utils.marlin_24_perms
import
(
marlin_24_perm
,
marlin_24_scale_perm
,
marlin_24_scale_perm_single
)
from
vllm.model_executor.layers.quantization.utils.marlin_perms
import
(
marlin_perm
,
marlin_scale_perm
,
marlin_scale_perm_single
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
sort_weights
)
get_pack_factor
,
quantize_weights
,
sort_weights
)
__cuda_arch
=
torch
.
cuda
.
get_device_capability
()
__cuda_arch
=
torch
.
cuda
.
get_device_capability
()
MARLIN_TILE
=
16
def
is_marlin_supported
():
def
is_marlin_supported
():
return
__cuda_arch
[
0
]
>=
8
return
__cuda_arch
[
0
]
>=
8
# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
def
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
,
tile
=
MARLIN_TILE
):
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def
_get_perms
(
num_bits
):
perm_list
=
[]
for
i
in
range
(
32
):
perm1
=
[]
col
=
i
//
4
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col
+
8
*
block
)
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm_list
)
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
perm
=
perm
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
scale_perm
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
perm
,
scale_perm
,
scale_perm_single
_perm
=
{}
_scale_perm
=
{}
_scale_perm_single
=
{}
for
num_bits
in
[
4
,
8
]:
perm
,
scale_perm
,
scale_perm_single
=
_get_perms
(
num_bits
)
_perm
[
num_bits
]
=
perm
_scale_perm
[
num_bits
]
=
scale_perm
_scale_perm_single
[
num_bits
]
=
scale_perm_single
def
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
tile
=
GPTQ_MARLIN_TILE
):
assert
q_w
.
shape
==
(
size_k
,
size_n
)
assert
q_w
.
shape
==
(
size_k
,
size_n
)
assert
size_k
%
tile
==
0
,
f
"size_k =
{
size_k
}
, tile =
{
tile
}
"
assert
size_k
%
tile
==
0
,
f
"size_k =
{
size_k
}
, tile =
{
tile
}
"
assert
size_n
%
tile
==
0
,
f
"size_k =
{
size_n
}
, tile =
{
tile
}
"
assert
size_n
%
tile
==
0
,
f
"size_k =
{
size_n
}
, tile =
{
tile
}
"
...
@@ -83,15 +32,14 @@ def marlin_permute_weights(q_w,
...
@@ -83,15 +32,14 @@ def marlin_permute_weights(q_w,
q_w
=
q_w
.
permute
((
0
,
2
,
1
,
3
))
q_w
=
q_w
.
permute
((
0
,
2
,
1
,
3
))
q_w
=
q_w
.
reshape
((
size_k
//
tile
,
size_n
*
tile
))
q_w
=
q_w
.
reshape
((
size_k
//
tile
,
size_n
*
tile
))
q_w
=
q_w
.
reshape
(
q_w
=
q_w
.
reshape
((
-
1
,
perm
.
numel
()))[:,
perm
].
reshape
(
q_w
.
shape
)
(
-
1
,
_perm
[
num_bits
].
numel
()))[:,
_perm
[
num_bits
]].
reshape
(
q_w
.
shape
)
return
q_w
return
q_w
def
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
):
def
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
perm
):
# Permute
# Permute
q_w
=
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
num_bits
)
q_w
=
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
)
# Pack
# Pack
pack_factor
=
get_pack_factor
(
num_bits
)
pack_factor
=
get_pack_factor
(
num_bits
)
...
@@ -101,7 +49,6 @@ def marlin_weights(q_w, size_k, size_n, num_bits):
...
@@ -101,7 +49,6 @@ def marlin_weights(q_w, size_k, size_n, num_bits):
q_packed
=
numpy
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
q_packed
=
numpy
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
dtype
=
numpy
.
uint32
)
dtype
=
numpy
.
uint32
)
for
i
in
range
(
pack_factor
):
for
i
in
range
(
pack_factor
):
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
...
@@ -110,15 +57,12 @@ def marlin_weights(q_w, size_k, size_n, num_bits):
...
@@ -110,15 +57,12 @@ def marlin_weights(q_w, size_k, size_n, num_bits):
return
q_packed
return
q_packed
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
num_bits
):
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
scale_perm
,
scale_perm_single
):
if
group_size
<
size_k
and
group_size
!=
-
1
:
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
_scale_perm
[
num_bits
])))[:,
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
_scale_perm
[
num_bits
]]
else
:
else
:
s
=
s
.
reshape
(
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
(
-
1
,
len
(
_scale_perm_single
[
num_bits
])))[:,
_scale_perm_single
[
num_bits
]]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
return
s
...
@@ -148,8 +92,11 @@ def marlin_quantize(
...
@@ -148,8 +92,11 @@ def marlin_quantize(
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
# Reformat to marlin
# Reformat to marlin
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
)
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
marlin_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
num_bits
)
marlin_perm
[
num_bits
])
marlin_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
marlin_scale_perm
[
num_bits
],
marlin_scale_perm_single
[
num_bits
])
# Create result
# Create result
res_list
=
[
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
rand_perm
]
res_list
=
[
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
rand_perm
]
...
@@ -159,15 +106,118 @@ def marlin_quantize(
...
@@ -159,15 +106,118 @@ def marlin_quantize(
return
res_list
return
res_list
def
inject_24
(
w
,
size_k
,
size_n
):
assert
w
.
shape
==
(
size_k
,
size_n
)
mask
=
mask_creator
(
w
.
t
()).
t
().
cuda
().
bool
()
return
(
mask
*
w
).
contiguous
(),
mask
.
contiguous
()
def
check_24
(
w
,
num_rows_to_sample
=
50
,
_verbose
=
False
):
BLOCK_SIZE
=
4
MAX_NON_ZEROS
=
2
w
=
w
.
t
().
contiguous
()
print
(
"check_24: w.shape = {}"
.
format
(
w
.
shape
))
num_rows
,
num_cols
=
w
.
shape
sampled_row_idxs
=
random
.
choices
(
range
(
num_rows
),
k
=
num_rows_to_sample
)
if
_verbose
:
print
(
f
"Sampled row idxs =
{
sampled_row_idxs
}
"
)
total_segments
=
0
non_24_segments
=
0
for
i
in
sampled_row_idxs
:
for
j
in
range
(
0
,
num_cols
-
BLOCK_SIZE
,
BLOCK_SIZE
):
total_segments
+=
1
block
=
w
[
i
,
j
:
j
+
BLOCK_SIZE
]
num_nonzero
=
torch
.
count_nonzero
(
block
)
if
num_nonzero
>
MAX_NON_ZEROS
:
print
(
"i = {} j = {} block = {}"
.
format
(
i
,
j
,
block
))
non_24_segments
+=
1
print
(
f
"
{
non_24_segments
}
/
{
total_segments
}
do not have 2:4 structure."
)
def
compress_quantized_24_weight
(
q_24
,
size_k
,
size_n
,
num_bits
):
assert
q_24
.
shape
==
(
size_k
,
size_n
)
# Remove zp to normalize over 0
max_q_val
=
(
1
<<
num_bits
)
-
1
zp
=
(
max_q_val
+
1
)
//
2
q_24_no_zp
=
q_24
-
zp
# Compress
q_24_no_zp
=
q_24_no_zp
.
t
().
contiguous
()
q_24_no_zp_comp
,
meta
=
sparse_semi_structured_from_dense_cutlass
(
q_24_no_zp
)
q_24_no_zp_comp
=
q_24_no_zp_comp
.
t
().
contiguous
()
# Restore zp
q_24_comp
=
q_24_no_zp_comp
+
zp
# Resize meta to its actual shape (without moving any data)
meta
=
meta
.
resize_
(
meta
.
shape
[
1
]
//
2
,
meta
.
shape
[
0
]
*
2
)
return
q_24_comp
,
meta
def
marlin_24_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
# Inject 2:4 sparsity
w_24
,
mask_24
=
inject_24
(
w
,
size_k
,
size_n
)
# Quantize
w_24_ref
,
q_w_24
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w_24
,
num_bits
,
group_size
,
act_order
=
False
)
# Compress quantized weight
q_w_24_comp
,
meta
=
compress_quantized_24_weight
(
q_w_24
,
size_k
,
size_n
,
num_bits
)
size_k_comp
=
size_k
//
2
# Reformat to marlin
marlin_24_q_w_comp
=
marlin_weights
(
q_w_24_comp
,
size_k_comp
,
size_n
,
num_bits
,
marlin_24_perm
[
num_bits
])
marlin_24_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
marlin_24_scale_perm
[
num_bits
],
marlin_24_scale_perm_single
[
num_bits
])
# Create result
res_list
=
[
w_24_ref
,
marlin_24_q_w_comp
,
meta
,
marlin_24_s
]
for
i
in
range
(
len
(
res_list
)):
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
def
compute_max_diff
(
output
,
output_ref
):
return
torch
.
mean
(
torch
.
abs
(
output
-
output_ref
))
/
torch
.
mean
(
torch
.
abs
(
output_ref
))
class
MarlinWorkspace
:
class
MarlinWorkspace
:
def
__init__
(
self
,
out_features
):
def
__init__
(
self
,
out_features
,
min_thread_n
,
max_parallel
):
assert
(
out_features
%
GPTQ_MARLIN_MIN_THREAD_N
==
0
),
(
assert
(
out_features
%
min_thread_n
==
0
),
(
"out_features = {} is undivisible by
GPTQ_MARLIN_MIN_THREAD_N = {}"
"out_features = {} is undivisible by
min_thread_n = {}"
.
format
(
.
format
(
out_features
,
GPTQ_MARLIN_MIN_THREAD_N
))
out_features
,
min_thread_n
))
max_workspace_size
=
((
out_features
//
GPTQ_MARLIN_MIN_THREAD_N
)
*
max_workspace_size
=
((
out_features
//
min_thread_n
)
*
max_parallel
)
GPTQ_MARLIN_MAX_PARALLEL
)
self
.
scratch
=
torch
.
zeros
(
max_workspace_size
,
self
.
scratch
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment