OpenDAS / text-generation-inference

Commit 93d2b9fe (unverified), authored Jul 24, 2024 by Daniël de Kok, committed by GitHub on Jul 24, 2024
Parent: 86422506

Split up `layers.marlin` into several files (#2292)

The marlin.py file was getting large, split it up.

Showing 6 changed files with 917 additions and 1 deletion:

    server/tests/utils/test_weights.py                          +4    -1
    server/text_generation_server/layers/marlin/__init__.py     +20   -0
    server/text_generation_server/layers/marlin/fp8.py          +140  -0
    server/text_generation_server/layers/marlin/gptq.py         +266  -0
    server/text_generation_server/layers/marlin/marlin.py       +346  -0
    server/text_generation_server/layers/marlin/util.py         +141  -0
server/tests/utils/test_weights.py

@@ -8,7 +8,10 @@ from text_generation_server.utils.weights import (
 )
 from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader
 from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader
-from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader
+from text_generation_server.layers.marlin.marlin import (
+    MarlinWeight,
+    MarlinWeightsLoader,
+)
 from types import SimpleNamespace
 from typing import List, Optional, Dict, Union
 from pathlib import Path
server/text_generation_server/layers/marlin/__init__.py (new file)

from typing import List, Tuple

import torch

from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
from text_generation_server.layers.marlin.gptq import (
    GPTQMarlinLinear,
    GPTQMarlinWeight,
    can_use_gptq_marlin,
    repack_gptq_for_marlin,
)
from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader

__all__ = [
    "GPTQMarlinFP8Linear",
    "GPTQMarlinLinear",
    "GPTQMarlinWeight",
    "MarlinWeightsLoader",
    "can_use_gptq_marlin",
    "repack_gptq_for_marlin",
]
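The `__init__.py` above re-exports the loader and the GPTQ/FP8 entry points, but not the plain `MarlinWeight` dataclass, which is why the test above now imports it from the concrete submodule. A minimal sketch of the two import styles after this commit (assuming the `text_generation_server` package is on the path; not part of the commit itself):

# Public entry points keep working through the package root, via the
# re-exports in layers/marlin/__init__.py shown above.
from text_generation_server.layers.marlin import (
    MarlinWeightsLoader,
    can_use_gptq_marlin,
)

# The raw weight dataclass is not re-exported; it now lives in the submodule.
from text_generation_server.layers.marlin.marlin import MarlinWeight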
server/text_generation_server/layers/marlin/fp8.py (new file)

from typing import Optional

import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.layers.marlin.gptq import _check_valid_shape
from text_generation_server.layers.marlin.util import (
    _check_marlin_kernels,
    permute_scales,
)
from text_generation_server.utils.log import log_once

try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None


MARLIN_TILE_SIZE = 16


class GPTQMarlinFP8Linear(nn.Module):
    """
    FP8 GPTQ-Marlin linear layer.
    """

    def __init__(
        self,
        qweight: torch.Tensor,
        scales: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> None:
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")

        scales = scales.unsqueeze(0)
        if scales.shape[1] == 1:
            out_features, in_features = qweight.shape
            scales = scales.repeat(1, out_features)
        qweight, scales = repack_fp8_for_marlin(qweight, scales)

        in_features = qweight.shape[0] * MARLIN_TILE_SIZE
        out_features = scales.shape[1]
        _check_valid_shape(in_features=in_features, out_features=out_features)

        self.qweight = qweight
        self.scales = scales
        self.bias = bias if bias is not None else None

        self.workspace = torch.zeros(
            out_features // 64 * 16, dtype=torch.int, device=qweight.device
        )

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
        qweight, scales = fp8_quantize(weight)
        return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)

    @classmethod
    def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
        return cls(qweight=weight, scales=scale.to(dtype), bias=bias)

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        assert marlin_kernels is not None

        A_flat = A.view(-1, A.shape[-1])
        C = marlin_kernels.fp8_marlin_gemm(
            A_flat,
            self.qweight,
            self.scales,
            self.workspace,
            8,
            A_flat.shape[0],
            self.scales.shape[1],
            A_flat.shape[1],
        )
        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C


def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements).
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn

    if fp8_tensor.shape[0] % 4 != 0:
        raise ValueError(
            f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
        )

    # Reshape to prepare for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Convert fp8 to uint8 (byte) representation
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = torch.zeros(
        fp8_tensor.shape[0] // 4,
        fp8_tensor.shape[1],
        dtype=torch.int32,
        device=fp8_tensor.device,
    )

    for i in range(4):
        packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)

    return packed


def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
    """
    Repack FP8 tensor for GPTQ-Marlin.
    """

    out_features, in_features = weight.shape

    # Torch linear layers weights with shape [out_features, in_features],
    # GPTQ-quantized weights use [in_feateres/pack_factor, in_features],
    # so transpose before packing.
    qweight = pack_fp8_as_int32(weight.t())

    perm = torch.empty(0, dtype=torch.int, device=qweight.device)
    repacked = marlin_kernels.gptq_marlin_repack(
        qweight, perm, in_features, out_features, 8
    )

    scales = permute_scales(scales)

    return repacked, scales
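`pack_fp8_as_int32` above stores four FP8 values (one raw byte each) in a single int32, least-significant byte first. A tiny standalone sketch of that packing scheme using plain integers (illustrative only, not part of the commit):

# Four FP8 values, viewed as raw bytes, packed little-end-first into one int32,
# mirroring the bitwise_or_/shift loop in pack_fp8_as_int32 above.
raw_bytes = [0x11, 0x22, 0x33, 0x44]
packed = 0
for i, b in enumerate(raw_bytes):
    packed |= b << (i * 8)  # byte i occupies bits [8*i, 8*i + 8)
assert packed == 0x44332211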
server/text_generation_server/layers/marlin/gptq.py (new file)

from dataclasses import dataclass
from typing import Optional

import numpy
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.marlin.util import (
    _check_marlin_kernels,
    marlin_zero_points,
    permute_scales,
    unpack_cols,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight

try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None

try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
except Exception:
    has_sm_8_0 = False


GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16


def can_use_gptq_marlin(
    *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
    return (
        SYSTEM == "cuda"
        and marlin_kernels is not None
        and has_sm_8_0
        and quantize in {"awq", "gptq"}
        and quant_method in {"awq", "gptq"}
        and bits in GPTQ_MARLIN_BITS
        and groupsize in GPTQ_MARLIN_GROUP_SIZES
        # We only suppord asymmetric quantization for AWQ.
        and (sym or quant_method == "awq")
    )


@dataclass
class GPTQMarlinWeight(Weight):
    """
    Repacked GPTQ Marlin weights.
    """

    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: torch.Tensor
    perm: torch.Tensor
    bits: int
    is_full_k: bool

    def __post_init__(self):
        assert self.qweight.dtype == torch.int32
        assert self.scales.dtype == torch.float16
        assert self.g_idx.dtype == torch.int32
        assert self.perm.dtype == torch.int32

    def get_linear(self, bias: torch.Tensor):
        return GPTQMarlinLinear(
            weight=self,
            bias=bias,
        )


def repack_gptq_for_marlin(
    *,
    qweight: torch.Tensor,
    qzeros: Optional[torch.Tensor],
    scales: torch.Tensor,
    g_idx: Optional[torch.Tensor],
    bits: int,
    desc_act: bool,
    groupsize: int,
    quant_method: str,
    sym: bool,
    sharded_infeatures: bool,
) -> GPTQMarlinWeight:
    """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
    _check_marlin_kernels()
    assert marlin_kernels is not None

    if bits not in GPTQ_MARLIN_BITS:
        supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
        raise RuntimeError(
            f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
        )

    if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
        supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
        raise RuntimeError(
            f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
        )
    if not (sym or quant_method == "awq"):
        raise RuntimeError(
            "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
        )

    log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")

    weights_per_int = 32 // bits
    in_features = qweight.shape[0]
    out_features = qweight.shape[1]

    # AWQ uses column packing, GPTQ uses row packing
    if quant_method == "awq":
        out_features *= weights_per_int
    else:
        in_features *= weights_per_int

    if in_features % groupsize != 0:
        raise ValueError(
            f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
        )

    if g_idx is not None and desc_act and groupsize != -1:
        perm = torch.argsort(g_idx).to(torch.int)
        g_idx = g_idx[perm]
    else:
        perm = torch.empty(0, dtype=torch.int, device=qweight.device)
        g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)

    if quant_method == "awq":
        repacked = marlin_kernels.awq_marlin_repack(
            qweight, in_features, out_features, bits
        )
        if qzeros is not None:
            qzeros = awq_to_marlin_zero_points(
                qzeros,
                in_features // groupsize,
                out_features,
                bits,
            )
    else:
        repacked = marlin_kernels.gptq_marlin_repack(
            qweight, perm, in_features, out_features, bits
        )

    if qzeros is None:
        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)

    scales = permute_scales(scales)

    is_full_k = not (desc_act and sharded_infeatures)

    return GPTQMarlinWeight(
        qweight=repacked,
        qzeros=qzeros,
        scales=scales,
        g_idx=g_idx,
        perm=perm,
        bits=bits,
        is_full_k=is_full_k,
    )


class GPTQMarlinLinear(nn.Module):
    """
    Linear layer for GPTQ weights that were converted for the GPTQ-Marlin
    kernels.
    """

    def __init__(
        self,
        *,
        weight: GPTQMarlinWeight,
        bias: Optional[torch.Tensor],
    ):
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
        out_features = weight.scales.shape[1]
        _check_valid_shape(in_features=in_features, out_features=out_features)

        self.bits = weight.bits
        self.is_full_k = weight.is_full_k

        self.qweight = weight.qweight
        self.qzeros = weight.qzeros
        self.scales = weight.scales
        self.g_idx = weight.g_idx
        self.perm = weight.perm
        if bias is not None:
            self.bias = bias
        else:
            self.bias = None

        self.workspace = torch.zeros(
            out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        assert marlin_kernels is not None

        A_flat = A.view(-1, A.shape[-1])
        C = marlin_kernels.gptq_marlin_gemm(
            A_flat,
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.perm,
            self.workspace,
            self.bits,
            A_flat.shape[0],
            self.scales.shape[1],
            A_flat.shape[1],
            self.is_full_k,
            self.qzeros.numel() > 0,
        )
        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C


def awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    # AWQ zero-points are quantized and packed on the column dim.
    # In addition, the values are permuted based on dequantizer.
    # Here we undo both of these, and then apply marlin permutation
    # and pack it back.
    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo interleaving (use argsort(..) to get inverse perm)
    if num_bits == 4:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
    elif num_bits == 8:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
    return marlin_zp


def _check_valid_shape(in_features: int, out_features: int):
    if (in_features % 128 != 0 or out_features % 64 != 0) and (
        in_features % 64 != 0 or out_features % 128 != 0
    ):
        raise ValueError(
            f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
            " The shape elements must be divisible by (128, 64) or (64, 128)."
        )
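`can_use_gptq_marlin` above gates the fast path: it only returns True on CUDA with compute capability 8.0+, with the marlin kernels importable, for 4- or 8-bit GPTQ/AWQ checkpoints whose group size is one of the supported values. A hypothetical call site (the configuration values are illustrative, not part of the commit):

from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin

# Typical 4-bit, group-size-128, symmetric GPTQ checkpoint.
if can_use_gptq_marlin(
    bits=4, groupsize=128, quant_method="gptq", quantize="gptq", sym=True
):
    print("repacking weights for the GPTQ-Marlin kernels")
else:
    print("falling back to the regular GPTQ path")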
server/text_generation_server/layers/marlin.py → server/text_generation_server/layers/marlin/marlin.py

 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
-import numpy
 import torch
 import torch.nn as nn
-from loguru import logger
-from text_generation_server.layers.fp8 import fp8_quantize
-from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.utils.log import log_once
+from text_generation_server.layers.marlin.util import _check_marlin_kernels
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
 try:
@@ -15,17 +11,6 @@ try:
 except ImportError:
     marlin_kernels = None
 
-try:
-    major, _minor = torch.cuda.get_device_capability()
-    has_sm_8_0 = major >= 8
-except Exception:
-    has_sm_8_0 = False
-
-GPTQ_MARLIN_BITS = [4, 8]
-GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
-
-MARLIN_TILE_SIZE = 16
-
 
 class MarlinWeightsLoader(WeightsLoader):
     """Loader for Marlin-quantized weights."""
@@ -168,244 +153,73 @@ class MarlinWeightsLoader(WeightsLoader):
         return weight
 
[removed here: can_use_gptq_marlin, _check_marlin_kernels, _check_valid_shape, _get_perms,
 permute_scales, GPTQMarlinWeight, repack_gptq_for_marlin and GPTQMarlinLinear, which moved
 into layers/marlin/gptq.py and layers/marlin/util.py (both shown in full on this page);
 in util.py, _get_perms became the functools-cached get_perms, and permute_scales now calls
 it instead of reading the module-level _scale_perm globals]
+
+@dataclass
+class MarlinWeight(Weight):
+    """
+    Marlin weights.
+
+    Attributes:
+        B (torch.Tensor): int4-quantized weights packed into int32.
+        s (torch.Tensor): bfloat16/float16 scales.
+    """
+
+    B: torch.Tensor
+    s: torch.Tensor
+
+    def __post_init__(self):
+        assert self.B.dtype == torch.int32
+        assert self.s.dtype in [torch.float16, torch.bfloat16]
+
+    def get_linear(self, bias: torch.Tensor):
+        return MarlinLinear(weight=self, bias=bias)
+
+
+class MarlinLinear(nn.Module):
+    def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
+        super().__init__()
+
+        _check_marlin_kernels()
+        assert marlin_kernels is not None
+
+        in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
+        out_features = weight.s.shape[1]
+        assert (
+            in_features % 128 == 0
+        ), f"Number of input features ({in_features}) not divisable by 128"
+        assert (
+            out_features % 256 == 0
+        ), f"Number of output features ({out_features}) not divisable by 256"
+
+        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
+        assert groupsize in {
+            -1,
+            128,
+        }, f"Group size must be -1 or 128, was {groupsize}"
+
+        self.B = weight.B
+        self.s = weight.s
+        if bias is not None:
+            self.bias = bias
+        else:
+            self.bias = None
+
+        self.workspace = torch.zeros(
+            out_features // 64 * 16, dtype=torch.int, device=weight.B.device
+        )
+
+    def forward(self, A: torch.Tensor) -> torch.Tensor:
+        assert marlin_kernels is not None
+
+        C = marlin_kernels.marlin_gemm(
+            A.view(-1, A.shape[-1]),
+            self.B,
+            self.s,
+            self.workspace,
+            A.shape[0],
+            self.s.shape[1],
+            A.shape[1],
+        )
+        C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
+
+        if self.bias is not None:
+            C += self.bias
+
+        return C
@@ -418,6 +232,7 @@ GPTQ_MARLIN_24_MIN_THREAD_K = 128
 GPTQ_MARLIN_24_MAX_PARALLEL = 64
 GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
 GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
+MARLIN_TILE_SIZE = 16
 
 
 @dataclass
@@ -456,8 +271,10 @@ class GPTQMarlin24Linear(nn.Module):
         _check_marlin_kernels()
         assert marlin_kernels is not None
 
-        if weight.bits not in GPTQ_MARLIN_BITS:
-            supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
+        if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
+            supported_bits = ", ".join(
+                str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
+            )
             raise RuntimeError(
                 f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
             )
@@ -527,310 +344,3 @@ class GPTQMarlin24Linear(nn.Module):
             C += self.bias
 
         return C
[removed here: GPTQMarlinFP8Linear, pack_fp8_as_int32 and repack_fp8_for_marlin (moved to
 layers/marlin/fp8.py), the MarlinWeight and MarlinLinear definitions (moved up in this
 file, see the added block above), the vLLM-derived helpers get_pack_factor, pack_cols,
 unpack_cols and marlin_zero_points (moved to layers/marlin/util.py), and
 awq_to_marlin_zero_points (moved to layers/marlin/gptq.py)]
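The `MarlinLinear` layer added above derives the logical layer sizes from the packed tensors and enforces the kernel's divisibility rules. A small illustrative check with hypothetical layer sizes (not part of the commit):

# Shape arithmetic implied by MarlinLinear above (example sizes are made up).
MARLIN_TILE_SIZE = 16
in_features, out_features = 4096, 4096

# in_features is recovered as B.shape[0] * MARLIN_TILE_SIZE, so B has 256 rows here;
# out_features is s.shape[1].
assert in_features % 128 == 0, "input features must be divisible by 128"
assert out_features % 256 == 0, "output features must be divisible by 256"

# The group size is inferred from the scales: -1 (a single row of scales) or
# in_features // s.shape[0]; only -1 and 128 are accepted.
for s_rows in (1, in_features // 128):
    groupsize = -1 if s_rows == 1 else in_features // s_rows
    assert groupsize in {-1, 128}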
server/text_generation_server/layers/marlin/util.py (new file)

import functools
from typing import List, Tuple

import numpy
import torch
from text_generation_server.utils.import_utils import SYSTEM

try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None

try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
except Exception:
    has_sm_8_0 = False


def _check_marlin_kernels():
    if not (SYSTEM == "cuda" and has_sm_8_0):
        raise NotImplementedError(
            "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
        )

    if marlin_kernels is None:
        raise NotImplementedError(
            "marlin is not installed, install it with: pip install server/marlin"
        )


# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
@functools.cache
def get_perms() -> Tuple[List[int], List[int]]:
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def permute_scales(scales: torch.Tensor):
    scale_perm, scale_perm_single = get_perms()
    out_features = scales.shape[1]
    if scales.shape[0] == 1:
        scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    else:
        scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm]
    return scales.reshape((-1, out_features)).contiguous()


# Functions below are from vLLM


def get_pack_factor(bits: int) -> int:
    if 32 % bits != 0:
        raise ValueError(f"Cannot {bits} bit values into uint32")
    return 32 // bits


def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def marlin_zero_points(
    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    scale_perm, _ = get_perms()
    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    zp = pack_cols(zp, num_bits, size_k, size_n)

    return zp
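A hypothetical round-trip through the vLLM-derived packing helpers above, to illustrate how `pack_cols` and `unpack_cols` relate (assumes the `text_generation_server` package plus numpy and torch are importable; the sizes are made up):

import torch

from text_generation_server.layers.marlin.util import pack_cols, unpack_cols

size_k, size_n, num_bits = 4, 16, 4
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

# pack_factor = 32 // num_bits = 8, so the column count shrinks from 16 to 2.
packed = pack_cols(q_w, num_bits, size_k, size_n)
assert packed.shape == (size_k, size_n // (32 // num_bits))

# unpack_cols restores the original 4-bit values.
assert torch.equal(unpack_cols(packed, num_bits, size_k, size_n), q_w)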