Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b675069d
Unverified
Commit
b675069d
authored
Jul 11, 2024
by
Robert Shaw
Committed by
GitHub
Jul 11, 2024
Browse files
[ Misc ] Refactor Marlin Python Utilities (#6082)
Co-authored-by:
Robert Shaw
<
rshaw@neuralmagic.com
>
parent
55f692b4
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
704 additions
and
742 deletions
+704
-742
benchmarks/kernels/benchmark_marlin.py
benchmarks/kernels/benchmark_marlin.py
+6
-4
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+25
-21
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+14
-9
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+69
-82
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+1
-1
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+67
-196
vllm/model_executor/layers/quantization/utils/marlin_24_perms.py
...del_executor/layers/quantization/utils/marlin_24_perms.py
+0
-60
vllm/model_executor/layers/quantization/utils/marlin_perms.py
.../model_executor/layers/quantization/utils/marlin_perms.py
+0
-60
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+133
-306
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+109
-0
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
...l_executor/layers/quantization/utils/marlin_utils_test.py
+120
-0
vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py
...xecutor/layers/quantization/utils/marlin_utils_test_24.py
+160
-3
No files found.
benchmarks/kernels/benchmark_marlin.py
View file @
b675069d
...
...
@@ -5,14 +5,16 @@ import torch.utils.benchmark as benchmark
from
benchmark_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
MarlinWorkspace
,
marlin_24_quantize
,
marlin_quantize
)
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
,
marlin_quantize
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
marlin_24_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
quantize_weights
,
sort_weights
)
from
vllm.utils
import
FlexibleArgumentParser
...
...
tests/kernels/test_marlin_gemm.py
View file @
b675069d
...
...
@@ -5,19 +5,21 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
import
pytest
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
,
marlin_permute_scales
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_perms
import
(
marlin_perm
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
MarlinWorkspace
,
compute_max_diff
,
is_marlin_supported
,
marlin_24_quantize
,
marlin_quantize
,
marlin_weights
,
pack_fp8_to_int32
)
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
,
marlin_permute_scales
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
pack_fp8_to_int32
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
,
get_weight_perm
,
marlin_quantize
,
marlin_weights
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
marlin_24_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
quantize_weights
,
sort_weights
)
...
...
@@ -42,11 +44,16 @@ MNK_FACTORS = [
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
def
compute_max_diff
(
output
,
output_ref
):
return
torch
.
mean
(
torch
.
abs
(
output
-
output_ref
))
/
torch
.
mean
(
torch
.
abs
(
output_ref
))
def
rand_data
(
shape
,
dtype
=
torch
.
float16
):
return
torch
.
randn
(
shape
,
dtype
=
dtype
,
device
=
"cuda"
)
@
pytest
.
mark
.
skipif
(
not
is_
marlin_supported
(
),
@
pytest
.
mark
.
skipif
(
not
is_
quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
...
...
@@ -93,8 +100,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
# Pack to Marlin format
marlin_q_w_1
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
marlin_perm
[
num_bits
]
)
weight_perm
=
get_weight_perm
(
num_bits
)
marlin_q_w_1
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
weight_perm
)
# Run Marlin repack GPU kernel
marlin_q_w_2
=
ops
.
gptq_marlin_repack
(
...
...
@@ -109,7 +116,7 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
assert
torch
.
allclose
(
marlin_q_w_1
,
marlin_q_w_2
)
@
pytest
.
mark
.
skipif
(
not
is_
marlin_supported
(
),
@
pytest
.
mark
.
skipif
(
not
is_
quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
...
...
@@ -174,7 +181,7 @@ def test_marlin_gemm(
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_
marlin_supported
(
),
@
pytest
.
mark
.
skipif
(
not
is_
quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_24_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_24_N_CHUNKS
)
...
...
@@ -222,7 +229,7 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_
marlin
_supported
(),
@
pytest
.
mark
.
skipif
(
not
is_
quant_method
_supported
(
"fp8"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
...
...
@@ -268,13 +275,10 @@ def test_fp8_marlin_gemm(
# expand it to channelwise
scales
=
weight_scale
.
repeat
(
1
,
size_n
).
to
(
a_input
.
dtype
).
to
(
"cuda"
)
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
size_k
,
size_n
=
size_n
,
group_size
=-
1
,
num_bits
=
8
,
)
group_size
=-
1
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
...
...
tests/quantization/test_compressed_tensors.py
View file @
b675069d
...
...
@@ -6,7 +6,6 @@ Run `pytest tests/quantization/test_compressed_tensors.py`.
import
pytest
import
torch
from
vllm
import
SamplingParams
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsLinearMethod
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
...
...
@@ -57,12 +56,14 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
def
test_compressed_tensors_no_enforce_eager
(
vllm_runner
):
model_path
=
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
with
vllm_runner
(
model_path
)
as
llm
:
sampling_params
=
SamplingParams
()
output
=
llm
.
generate
(
"Hello world!"
,
sampling_params
=
sampling_params
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
...
...
@@ -84,13 +85,16 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
weight
.
dtype
is
torch
.
int8
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
@
pytest
.
mark
.
parametrize
(
"wNa16_args"
,
[(
"nm-testing/tinyllama-oneshot-w4a16-channel-v2"
,
"channel"
,
None
,
8
),
(
"nm-testing/tinyllama-oneshot-w4a16-group128-v2"
,
"group"
,
128
,
8
),
(
"nm-testing/tinyllama-oneshot-w8a16-per-channel"
,
"channel"
,
None
,
4
)])
def
test_compressed_tensors_w
4
a16
(
vllm_runner
,
wNa16_args
):
def
test_compressed_tensors_w
N
a16
(
vllm_runner
,
wNa16_args
):
model
,
strategy
,
group
,
pack_factor
=
wNa16_args
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
...
...
@@ -101,12 +105,15 @@ def test_compressed_tensors_w4a16(vllm_runner, wNa16_args):
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsWNA16
)
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
scheme
.
group_size
==
group
assert
qkv_proj
.
scheme
.
group_size
==
(
-
1
if
group
is
None
else
group
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float16
assert
qkv_proj
.
weight_packed
.
pack_factor
==
pack_factor
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
def
test_compressed_tensors_w4a16_marlin24
(
vllm_runner
):
model_path
=
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
...
...
@@ -120,8 +127,7 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW4A16Sparse24
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
sampling_params
=
SamplingParams
()
output
=
llm
.
generate
(
"Hello world!"
,
sampling_params
=
sampling_params
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
...
...
@@ -142,6 +148,5 @@ def test_compressed_tensors_fp8(vllm_runner):
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
sampling_params
=
SamplingParams
()
output
=
llm
.
generate
(
"Hello world!"
,
sampling_params
=
sampling_params
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
b675069d
...
...
@@ -6,9 +6,10 @@ from torch.nn import Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQMarlinState
,
marlin_permute_scales
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.utils
import
set_weight_attrs
__all__
=
[
"CompressedTensorsWNA16"
]
...
...
@@ -22,29 +23,40 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
num_bits
:
int
,
group_size
:
Optional
[
int
]
=
None
):
self
.
num_bits
=
num_bits
self
.
pack_factor
=
32
//
self
.
num_bits
self
.
strategy
=
strategy
self
.
group_size
=
group_size
if
self
.
strategy
==
"group"
and
self
.
group_size
is
None
:
self
.
group_size
:
int
if
group_size
is
None
:
if
self
.
strategy
!=
"channel"
:
raise
ValueError
(
"group_size must be given when using strategy group"
)
"Marlin kernels require group quantization or "
"channelwise quantization, but found no group "
"size and strategy is not channelwise."
)
self
.
group_size
=
-
1
else
:
self
.
group_size
=
group_size
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
pass
# Verify supported on platform.
verify_marlin_supported
(
num_bits
=
self
.
num_bits
,
group_size
=
self
.
group_size
,
is_sym
=
True
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
pack_factor
=
32
//
self
.
num_bits
output_size_per_partition
=
sum
(
output_partition_sizes
)
if
self
.
group_size
is
not
None
:
group_size
=
self
.
group_size
else
:
group_size
=
input_size
# If group_size is -1, we are in channelwise case.
group_size
=
input_size
if
self
.
group_size
==
-
1
else
self
.
group_size
verify_marlin_supports_shape
(
output_size_per_partition
=
output_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size
=
input_size
,
group_size
=
group_size
)
weight_scale_dim
=
None
scales_and_zp_size
=
input_size
//
group_size
...
...
@@ -57,7 +69,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
//
pack_factor
,
input_size_per_partition
//
self
.
pack_factor
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
...
...
@@ -68,7 +80,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
"input_dim"
:
1
,
"output_dim"
:
0
,
"packed_dim"
:
1
,
"pack_factor"
:
pack_factor
,
"pack_factor"
:
self
.
pack_factor
,
"weight_loader"
:
weight_loader
})
layer
.
register_parameter
(
"weight_packed"
,
weight
)
...
...
@@ -103,73 +115,48 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
input_size
=
input_size
layer
.
marlin_state
=
GPTQMarlinState
.
REPACK
layer
.
is_k_full
=
True
layer
.
group_size
=
group_size
max_workspace_size
=
(
output_size_per_partition
//
GPTQ_MARLIN_MIN_THREAD_N
)
*
GPTQ_MARLIN_MAX_PARALLEL
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
requires_grad
=
False
)
layer
.
workspace
=
workspace
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
size_m
=
reshaped_x
.
shape
[
0
]
part_size_n
=
layer
.
output_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
out_shape
=
x
.
shape
[:
-
1
]
+
(
part_size_n
,
)
if
layer
.
marlin_state
==
GPTQMarlinState
.
REPACK
:
layer
.
marlin_state
=
GPTQMarlinState
.
READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def
replace_tensor
(
name
,
new_t
):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr
(
layer
,
name
).
resize_
(
new_t
.
shape
)
getattr
(
layer
,
name
).
copy_
(
new_t
)
del
new_t
# Checkpoints are serialized in compressed-tensors format, which is
# different from marlin format. Handle repacking here.
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
weight_packed
.
device
cur_device
=
layer
.
weight_packed
.
device
# Allocate marlin workspace.
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
# Reset g_idx related tensors
layer
.
g_idx
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
)
layer
.
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
)
# Act-order not supported in compressed-tensors yet, so set to empty.
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# Repack weights
# Repack weights
from compressed-tensors format to marlin format.
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
weight_packed
.
t
().
contiguous
(),
layer
.
g_idx_sort_indices
,
part_size_k
,
part_size_n
,
self
.
num_bits
)
replace_tensor
(
"weight_packed"
,
marlin_qweight
)
# Permute scales
scales_size_k
=
part_size_k
scales_size_n
=
part_size_n
layer
.
weight_packed
.
t
().
contiguous
(),
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
num_bits
)
replace_tensor
(
layer
,
"weight_packed"
,
marlin_qweight
)
# Permute scales from compressed-tensors format to marlin format.
marlin_scales
=
marlin_permute_scales
(
layer
.
weight_scale
.
squeeze
().
t
().
contiguous
(),
scales_size_k
,
scales_size_n
,
layer
.
group_size
,
self
.
num_bits
)
replace_tensor
(
"weight_scale"
,
marlin_scales
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
layer
.
weight_packed
,
layer
.
weight_scale
,
layer
.
g_idx
,
layer
.
g_idx_sort_indices
,
layer
.
workspace
,
self
.
num_bits
,
size_m
,
part_size_n
,
part_size_k
,
layer
.
is_k_full
)
return
output
.
reshape
(
out_shape
)
layer
.
weight_scale
.
squeeze
().
t
().
contiguous
(),
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
group_size
=
layer
.
group_size
)
replace_tensor
(
layer
,
"weight_scale"
,
marlin_scales
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
return
apply_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight_packed
,
weight_scale
=
layer
.
weight_scale
,
g_idx
=
layer
.
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
num_bits
=
self
.
num_bits
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
is_k_full
=
True
)
vllm/model_executor/layers/quantization/fp8.py
View file @
b675069d
...
...
@@ -11,7 +11,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
apply_fp8_linear
,
create_per_tensor_scale_param
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
b675069d
import
enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
...
...
@@ -12,46 +10,14 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_K
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
,
GPTQ_MARLIN_SUPPORTED_SYM
,
GPTQ_MARLIN_TILE
)
check_marlin_supported
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
# Permutations for Marlin scale shuffling
def
get_scale_perms
(
num_bits
:
int
):
scale_perm
:
List
[
int
]
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
:
List
[
int
]
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
scale_perm
,
scale_perm_single
def
get_pack_factor
(
num_bits
:
int
):
assert
(
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
),
f
"Unsupported num_bits =
{
num_bits
}
"
return
32
//
num_bits
def
marlin_permute_scales
(
s
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
group_size
:
int
,
num_bits
:
int
):
scale_perm
,
scale_perm_single
=
get_scale_perms
(
num_bits
)
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
class
GPTQMarlinConfig
(
QuantizationConfig
):
"""Config class for GPTQ Marlin"""
...
...
@@ -63,33 +29,16 @@ class GPTQMarlinConfig(QuantizationConfig):
desc_act
=
False
self
.
weight_bits
=
weight_bits
self
.
pack_factor
=
32
//
self
.
weight_bits
# packed into int32
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
self
.
is_sym
=
is_sym
self
.
lm_head_quantized
=
lm_head_quantized
# Verify
if
self
.
weight_bits
not
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
:
raise
ValueError
(
f
"Marlin does not support weight_bits =
{
self
.
weight_bits
}
. "
f
"Only weight_bits =
{
GPTQ_MARLIN_SUPPORTED_NUM_BITS
}
"
"are supported."
)
if
self
.
group_size
not
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
:
raise
ValueError
(
f
"Marlin does not support group_size =
{
self
.
group_size
}
. "
f
"Only group_sizes =
{
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
if
self
.
is_sym
not
in
GPTQ_MARLIN_SUPPORTED_SYM
:
raise
ValueError
(
f
"Marlin does not support is_sym =
{
self
.
is_sym
}
. "
f
"Only sym =
{
GPTQ_MARLIN_SUPPORTED_SYM
}
are supported."
)
# Init
self
.
pack_factor
=
get_pack_factor
(
weight_bits
)
self
.
tile_size
=
GPTQ_MARLIN_TILE
self
.
min_thread_n
=
GPTQ_MARLIN_MIN_THREAD_N
self
.
min_thread_k
=
GPTQ_MARLIN_MIN_THREAD_K
self
.
max_parallel
=
GPTQ_MARLIN_MAX_PARALLEL
# Verify supported on platform.
verify_marlin_supported
(
num_bits
=
self
.
weight_bits
,
group_size
=
self
.
group_size
,
is_sym
=
self
.
is_sym
)
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQMarlinConfig(weight_bits=
{
self
.
weight_bits
}
, "
...
...
@@ -168,21 +117,10 @@ class GPTQMarlinConfig(QuantizationConfig):
or
desc_act
is
None
):
return
False
# If the capability of the device is too low, cannot convert.
major
,
minor
=
current_platform
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
if
device_capability
<
cls
.
get_min_capability
():
return
False
# Otherwise, can convert if model satisfies marlin constraints.
return
(
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
and
group_size
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and
sym
in
GPTQ_MARLIN_SUPPORTED_SYM
)
class
GPTQMarlinState
(
Enum
):
REPACK
=
enum
.
auto
()
READY
=
enum
.
auto
()
return
check_marlin_supported
(
num_bits
=
num_bits
,
group_size
=
group_size
,
is_sym
=
sym
,
min_capability
=
cls
.
get_min_capability
())
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
...
...
@@ -206,6 +144,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
**
extra_weight_attrs
,
)
->
None
:
del
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
...
...
@@ -213,31 +152,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
else
:
group_size
=
input_size
# Validate dtype
if
params_dtype
not
in
[
torch
.
float16
,
torch
.
bfloat16
]:
raise
ValueError
(
f
"The params dtype must be float16 "
f
"or bfloat16, but got
{
params_dtype
}
"
)
# Validate output_size_per_partition
output_size_per_partition
=
sum
(
output_partition_sizes
)
if
output_size_per_partition
%
self
.
quant_config
.
min_thread_n
!=
0
:
raise
ValueError
(
f
"Weight output_size_per_partition = "
f
"
{
output_size_per_partition
}
is not divisible by "
f
" min_thread_n =
{
self
.
quant_config
.
min_thread_n
}
."
)
# Validate input_size_per_partition
if
input_size_per_partition
%
self
.
quant_config
.
min_thread_k
!=
0
:
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible "
f
"by min_thread_k =
{
self
.
quant_config
.
min_thread_k
}
."
)
if
(
group_size
<
input_size
and
input_size_per_partition
%
group_size
!=
0
):
raise
ValueError
(
f
"Weight input_size_per_partition =
{
input_size_per_partition
}
"
f
" is not divisible by group_size =
{
group_size
}
."
)
verify_marlin_supports_shape
(
output_size_per_partition
=
output_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size
=
input_size
,
group_size
=
group_size
)
# Detect sharding of scales/zp
...
...
@@ -303,11 +222,6 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
},
)
g_idx_sort_indices
=
torch
.
empty
(
g_idx
.
shape
,
dtype
=
torch
.
int32
,
)
# Scales
scales
=
Parameter
(
torch
.
empty
(
...
...
@@ -347,114 +261,71 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
},
)
# Allocate marlin workspace
max_workspace_size
=
(
output_size_per_partition
//
self
.
quant_config
.
min_thread_n
)
*
self
.
quant_config
.
max_parallel
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
requires_grad
=
False
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
layer
.
workspace
=
workspace
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
input_size
=
input_size
layer
.
is_k_full
=
is_k_full
layer
.
marlin_state
=
GPTQMarlinState
.
REPACK
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
size_m
=
reshaped_x
.
shape
[
0
]
part_size_n
=
layer
.
output_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
full_size_k
=
layer
.
input_size
out_shape
=
x
.
shape
[:
-
1
]
+
(
part_size_n
,
)
if
layer
.
marlin_state
==
GPTQMarlinState
.
REPACK
:
layer
.
marlin_state
=
GPTQMarlinState
.
READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def
replace_tensor
(
name
,
new_t
):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr
(
layer
,
name
).
resize_
(
new_t
.
shape
)
getattr
(
layer
,
name
).
copy_
(
new_t
)
del
new_t
cur_device
=
layer
.
qweight
.
device
# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
qweight
.
device
# Allocate marlin workspace
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
# Process act_order
# Handle sorting for activation reordering if needed.
if
self
.
quant_config
.
desc_act
:
# Get sorting based on g_idx
g_idx_sort_indices
=
torch
.
argsort
(
layer
.
g_idx
).
to
(
torch
.
int
)
sorted_g_idx
=
layer
.
g_idx
[
g_idx_sort_indices
]
replace_tensor
(
"g_idx"
,
sorted_g_idx
)
replace_tensor
(
"g_idx_sort_indices"
,
g_idx_sort_indices
)
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
layer
.
g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
replace_tensor
(
layer
,
"g_idx"
,
g_idx
)
else
:
# Reset g_idx related tensors
layer
.
g_idx
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
,
)
layer
.
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
,
)
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# Repack weights
# Repack weights
from autogptq format to marlin format.
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
qweight
,
layer
.
g_idx_sort_indices
,
part_size_k
,
part_size_n
,
self
.
quant_config
.
weight_bits
,
)
replace_tensor
(
"qweight"
,
marlin_qweight
)
# Permute scales
scales_size_k
=
part_size_k
scales_size_n
=
part_size_n
if
self
.
quant_config
.
desc_act
:
scales_size_k
=
full_size_k
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
weight_bits
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from autogptq format to marlin format.
marlin_scales
=
marlin_permute_scales
(
layer
.
scales
,
scales_size_k
,
scales_size_n
,
self
.
quant_config
.
group_size
,
self
.
quant_config
.
weight_bits
,
)
replace_tensor
(
"scales"
,
marlin_scales
)
size_k
=
(
layer
.
input_size
if
self
.
quant_config
.
desc_act
else
layer
.
input_size_per_partition
),
size_n
=
layer
.
output_size_per_partition
,
group_size
=
self
.
quant_config
.
group_size
)
replace_tensor
(
layer
,
"scales"
,
marlin_scales
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out_shape
=
x
.
shape
[:
-
1
]
+
(
layer
.
output_size_per_partition
,
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
layer
.
qweight
,
layer
.
scales
,
layer
.
g_idx
,
layer
.
g_idx_sort_indices
,
layer
.
workspace
,
self
.
quant_config
.
weight_bits
,
size_m
,
part_size_n
,
part_size_k
,
layer
.
is_k_full
,
)
g_idx
=
layer
.
g_idx
,
perm
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
num_bits
=
self
.
quant_config
.
weight_bits
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
is_k_full
=
layer
.
is_k_full
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
...
...
vllm/model_executor/layers/quantization/utils/marlin_24_perms.py
deleted
100644 → 0
View file @
55f692b4
"""This file is used for /tests and /benchmarks"""
from
typing
import
Dict
,
List
import
numpy
import
torch
# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def
get_perms_24
(
num_bits
:
int
):
perm_list
:
List
[
int
]
=
[]
for
i
in
range
(
32
):
perm1
:
List
[
int
]
=
[]
col
=
i
//
4
col_o
=
col
//
2
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col_o
*
256
+
8
*
(
col
%
2
)
+
4
*
block
)
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
1
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm_list
)
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
ValueError
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
perm
=
perm
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
scale_perm
:
List
[
int
]
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
*
8
+
j
for
j
in
[
0
,
4
,
1
,
5
,
2
,
6
,
3
,
7
]])
scale_perm_single
:
List
[
int
]
=
[]
for
i
in
range
(
8
):
scale_perm_single
.
extend
([
8
*
i
+
j
for
j
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]])
return
perm
,
scale_perm
,
scale_perm_single
marlin_24_perm
:
Dict
[
int
,
torch
.
Tensor
]
=
{}
marlin_24_scale_perm
:
Dict
[
int
,
List
[
int
]]
=
{}
marlin_24_scale_perm_single
:
Dict
[
int
,
List
[
int
]]
=
{}
for
num_bits
in
[
4
,
8
]:
perm_24
,
scale_perm_24
,
scale_perm_single_24
=
get_perms_24
(
num_bits
)
marlin_24_perm
[
num_bits
]
=
perm_24
marlin_24_scale_perm
[
num_bits
]
=
scale_perm_24
marlin_24_scale_perm_single
[
num_bits
]
=
scale_perm_single_24
vllm/model_executor/layers/quantization/utils/marlin_perms.py
deleted
100644 → 0
View file @
55f692b4
"""This file is used for /tests and /benchmarks"""
from
typing
import
Dict
,
List
import
numpy
import
torch
# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def
get_perms
(
num_bits
:
int
):
perm_list
:
List
[
int
]
=
[]
for
i
in
range
(
32
):
perm1
:
List
[
int
]
=
[]
col
=
i
//
4
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col
+
8
*
block
)
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm_list
)
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
perm
=
perm
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
scale_perm
:
List
[
int
]
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
:
List
[
int
]
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
perm
,
scale_perm
,
scale_perm_single
marlin_perm
:
Dict
[
int
,
torch
.
Tensor
]
=
{}
marlin_scale_perm
:
Dict
[
int
,
List
[
int
]]
=
{}
marlin_scale_perm_single
:
Dict
[
int
,
List
[
int
]]
=
{}
for
num_bits
in
[
4
,
8
]:
perm
,
scale_perm
,
scale_perm_single
=
get_perms
(
num_bits
)
marlin_perm
[
num_bits
]
=
perm
marlin_scale_perm
[
num_bits
]
=
scale_perm
marlin_scale_perm_single
[
num_bits
]
=
scale_perm_single
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
b675069d
"""This file is used for /tests and /benchmarks"""
import
random
from
typing
import
Optional
from
typing
import
List
,
Optional
,
Tuple
import
numpy
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.format_24
import
(
mask_creator
,
sparse_semi_structured_from_dense_cutlass
)
from
vllm.model_executor.layers.quantization.utils.marlin_24_perms
import
(
marlin_24_perm
,
marlin_24_scale_perm
,
marlin_24_scale_perm_single
)
from
vllm.model_executor.layers.quantization.utils.marlin_perms
import
(
marlin_perm
,
marlin_scale_perm
,
marlin_scale_perm_single
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
sort_weights
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_MIN_THREAD_N
=
64
...
...
@@ -25,135 +13,110 @@ GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
def
is_marlin_supported
():
capability
=
current_platform
.
get_device_capability
()
return
capability
[
0
]
>=
8
def
apply_fp8_marlin_linear
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_n
:
int
,
size_k
:
int
,
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
size_n
,
)
output
=
ops
.
fp8_marlin_gemm
(
a
=
reshaped_x
,
b_q_weight
=
weight
,
b_scales
=
weight_scale
,
workspace
=
workspace
,
num_bits
=
8
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
size_n
,
size_k
=
size_k
,
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
def
prepare_fp8_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
)
->
None
:
print_warning_once
(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
part_size_n
=
layer
.
output_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
device
=
layer
.
weight
.
device
# WEIGHTS
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight
=
pack_fp8_to_int32
(
layer
.
weight
)
# Repack weights to marlin format
marlin_qweight
=
ops
.
gptq_marlin_repack
(
b_q_weight
=
packed_gptq_qweight
,
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
),
size_k
=
part_size_k
,
size_n
=
part_size_n
,
num_bits
=
8
,
)
layer
.
weight
=
torch
.
nn
.
Parameter
(
marlin_qweight
,
requires_grad
=
False
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales
=
layer
.
weight_scale
.
repeat
(
1
,
part_size_n
).
to
(
layer
.
orig_dtype
).
to
(
device
)
# Permute scales
num_bits
=
8
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
part_size_k
,
size_n
=
part_size_n
,
group_size
=-
1
,
scale_perm
=
marlin_scale_perm
[
num_bits
],
scale_perm_single
=
marlin_scale_perm_single
[
num_bits
])
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
marlin_scales
,
requires_grad
=
False
)
# Allocate marlin workspace
max_workspace_size
=
(
part_size_n
//
GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER
=
[
-
1
]
def
check_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
min_capability
:
int
)
->
bool
:
# If the capability of the device is too low, cannot convert.
major
,
minor
=
current_platform
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
if
device_capability
<
min_capability
:
return
False
return
(
device_capability
>=
min_capability
and
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
and
group_size
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and
is_sym
in
GPTQ_MARLIN_SUPPORTED_SYM
)
def
verify_marlin_supported
(
num_bits
:
int
,
group_size
:
Optional
[
int
],
is_sym
:
bool
)
->
None
:
if
num_bits
not
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
:
raise
ValueError
(
f
"Marlin does not support weight_bits =
{
num_bits
}
. "
f
"Only weight_bits =
{
GPTQ_MARLIN_SUPPORTED_NUM_BITS
}
"
"are supported."
)
if
(
group_size
is
None
or
group_size
not
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
):
raise
ValueError
(
f
"Marlin does not support group_size =
{
group_size
}
. "
f
"Only group_sizes =
{
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
if
is_sym
not
in
GPTQ_MARLIN_SUPPORTED_SYM
:
raise
ValueError
(
f
"Marlin does not support is_sym = is_sym. "
f
"Only sym =
{
GPTQ_MARLIN_SUPPORTED_SYM
}
are supported."
)
def
verify_marlin_supports_shape
(
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
input_size
:
int
,
group_size
:
int
)
->
None
:
# Validate output_size_per_partition
if
output_size_per_partition
%
GPTQ_MARLIN_MIN_THREAD_N
!=
0
:
raise
ValueError
(
f
"Weight output_size_per_partition = "
f
"
{
output_size_per_partition
}
is not divisible by "
f
" min_thread_n =
{
GPTQ_MARLIN_MIN_THREAD_N
}
. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq."
)
# Validate input_size_per_partition
if
input_size_per_partition
%
GPTQ_MARLIN_MIN_THREAD_K
!=
0
:
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible "
f
"by min_thread_k =
{
GPTQ_MARLIN_MIN_THREAD_K
}
. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq."
)
if
(
group_size
<
input_size
and
input_size_per_partition
%
group_size
!=
0
):
raise
ValueError
(
f
"Weight input_size_per_partition =
{
input_size_per_partition
}
"
f
" is not divisible by group_size =
{
group_size
}
."
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq."
)
def
marlin_make_workspace
(
output_size_per_partition
:
int
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
max_workspace_size
=
(
output_size_per_partition
//
GPTQ_MARLIN_MIN_THREAD_N
)
*
GPTQ_MARLIN_MAX_PARALLEL
workspace
=
torch
.
zeros
(
max_workspace_size
,
return
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
layer
.
workspace
=
workspace
def
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
,
tile
=
GPTQ_MARLIN_TILE
):
assert
q_w
.
shape
==
(
size_k
,
size_n
)
assert
size_k
%
tile
==
0
,
f
"size_k =
{
size_k
}
, tile =
{
tile
}
"
assert
size_n
%
tile
==
0
,
f
"size_k =
{
size_n
}
, tile =
{
tile
}
"
# Permute weights to 16x64 marlin tiles
q_w
=
q_w
.
reshape
((
size_k
//
tile
,
tile
,
size_n
//
tile
,
tile
))
q_w
=
q_w
.
permute
((
0
,
2
,
1
,
3
))
q_w
=
q_w
.
reshape
((
size_k
//
tile
,
size_n
*
tile
))
q_w
=
q_w
.
reshape
((
-
1
,
perm
.
numel
()))[:,
perm
].
reshape
(
q_w
.
shape
)
return
q_w
def
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
perm
):
# Permute
q_w
=
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
)
def
marlin_make_empty_g_idx
(
device
:
torch
.
device
)
->
torch
.
Tensor
:
return
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
),
requires_grad
=
False
)
# Pack
pack_factor
=
get_pack_factor
(
num_bits
)
orig_device
=
q_w
.
device
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
numpy
.
uint32
)
def
marlin_sort_g_idx
(
g_idx
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
g_idx_sort_indices
=
torch
.
argsort
(
g_idx
).
to
(
torch
.
int
)
return
g_idx
[
g_idx_sort_indices
],
g_idx_sort_indices
q_packed
=
numpy
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
dtype
=
numpy
.
uint32
)
for
i
in
range
(
pack_factor
):
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
q_packed
=
torch
.
from_numpy
(
q_packed
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
def
get_scale_perms
():
scale_perm
:
List
[
int
]
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
:
List
[
int
]
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
scale_perm
,
scale_perm_single
return
q_packed
def
marlin_permute_scales
(
s
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
group_size
:
int
)
->
torch
.
Tensor
:
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
scale_perm
,
scale_perm_single
):
scale_perm
,
scale_perm_single
=
get_scale_perms
()
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
...
...
@@ -163,180 +126,44 @@ def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
return
s
def
marlin_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
act_order
:
bool
,
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
# Quantize (and apply act_order if provided)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w
,
num_bits
,
group_size
,
act_order
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
if
act_order
:
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
# Reformat to marlin
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
marlin_perm
[
num_bits
])
marlin_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
marlin_scale_perm
[
num_bits
],
marlin_scale_perm_single
[
num_bits
])
# Create result
res_list
=
[
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
rand_perm
]
for
i
in
range
(
len
(
res_list
)):
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
def
inject_24
(
w
,
size_k
,
size_n
):
assert
w
.
shape
==
(
size_k
,
size_n
)
mask
=
mask_creator
(
w
.
t
()).
t
().
cuda
().
bool
()
return
(
mask
*
w
).
contiguous
(),
mask
.
contiguous
()
def
check_24
(
w
,
num_rows_to_sample
=
50
,
_verbose
=
False
):
BLOCK_SIZE
=
4
MAX_NON_ZEROS
=
2
w
=
w
.
t
().
contiguous
()
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def
replace_tensor
(
layer
:
torch
.
nn
.
Module
,
name
:
str
,
new_t
:
torch
.
Tensor
)
->
None
:
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr
(
layer
,
name
).
resize_
(
new_t
.
shape
)
getattr
(
layer
,
name
).
copy_
(
new_t
)
del
new_t
print
(
"check_24: w.shape = {}"
.
format
(
w
.
shape
))
num_rows
,
num_cols
=
w
.
shape
sampled_row_idxs
=
random
.
choices
(
range
(
num_rows
),
k
=
num_rows_to_sample
)
if
_verbose
:
print
(
f
"Sampled row idxs =
{
sampled_row_idxs
}
"
)
total_segments
=
0
non_24_segments
=
0
for
i
in
sampled_row_idxs
:
for
j
in
range
(
0
,
num_cols
-
BLOCK_SIZE
,
BLOCK_SIZE
):
total_segments
+=
1
block
=
w
[
i
,
j
:
j
+
BLOCK_SIZE
]
num_nonzero
=
torch
.
count_nonzero
(
block
)
if
num_nonzero
>
MAX_NON_ZEROS
:
print
(
"i = {} j = {} block = {}"
.
format
(
i
,
j
,
block
))
non_24_segments
+=
1
print
(
f
"
{
non_24_segments
}
/
{
total_segments
}
do not have 2:4 structure."
)
def
compress_quantized_24_weight
(
q_24
,
size_k
,
size_n
,
num_bits
):
assert
q_24
.
shape
==
(
size_k
,
size_n
)
# Remove zp to normalize over 0
max_q_val
=
(
1
<<
num_bits
)
-
1
zp
=
(
max_q_val
+
1
)
//
2
q_24_no_zp
=
q_24
-
zp
# Compress
q_24_no_zp
=
q_24_no_zp
.
t
().
contiguous
()
q_24_no_zp_comp
,
meta
=
sparse_semi_structured_from_dense_cutlass
(
q_24_no_zp
)
q_24_no_zp_comp
=
q_24_no_zp_comp
.
t
().
contiguous
()
# Restore zp
q_24_comp
=
q_24_no_zp_comp
+
zp
# Resize meta to its actual shape (without moving any data)
meta
=
meta
.
resize_
(
meta
.
shape
[
1
]
//
2
,
meta
.
shape
[
0
]
*
2
)
return
q_24_comp
,
meta
def
marlin_24_quantize
(
w
:
torch
.
Tensor
,
def
apply_marlin_linear
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
# Inject 2:4 sparsity
w_24
,
mask_24
=
inject_24
(
w
,
size_k
,
size_n
)
# Quantize
w_24_ref
,
q_w_24
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w_24
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
is_k_full
:
bool
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
weight
,
weight_scale
,
g_idx
,
g_idx_sort_indices
,
workspace
,
num_bits
,
group_size
,
act_order
=
False
)
# Compress quantized weight
q_w_24_comp
,
meta
=
compress_quantized_24_weight
(
q_w_24
,
size_k
,
size_n
,
num_bits
)
size_k_comp
=
size_k
//
2
# Reformat to marlin
marlin_24_q_w_comp
=
marlin_weights
(
q_w_24_comp
,
size_k_comp
,
size_n
,
num_bits
,
marlin_24_perm
[
num_bits
])
marlin_24_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
marlin_24_scale_perm
[
num_bits
],
marlin_24_scale_perm_single
[
num_bits
])
# Create result
res_list
=
[
w_24_ref
,
marlin_24_q_w_comp
,
meta
,
marlin_24_s
]
for
i
in
range
(
len
(
res_list
)):
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
def
compute_max_diff
(
output
,
output_ref
):
return
torch
.
mean
(
torch
.
abs
(
output
-
output_ref
))
/
torch
.
mean
(
torch
.
abs
(
output_ref
))
class
MarlinWorkspace
:
def
__init__
(
self
,
out_features
,
min_thread_n
,
max_parallel
):
assert
(
out_features
%
min_thread_n
==
0
),
(
"out_features = {} is undivisible by min_thread_n = {}"
.
format
(
out_features
,
min_thread_n
))
max_workspace_size
=
((
out_features
//
min_thread_n
)
*
max_parallel
)
self
.
scratch
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
"cuda"
)
def
pack_fp8_to_int32
(
fp8_tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Repack FP8 weights to gptq format (packed int32 elements)
"""
assert
fp8_tensor
.
dtype
==
torch
.
float8_e4m3fn
assert
fp8_tensor
.
shape
[
0
]
%
4
==
0
# Reshape to prepare for packing
reshaped
=
fp8_tensor
.
reshape
(
-
1
,
4
,
*
fp8_tensor
.
shape
[
1
:])
# Convert fp8 to uint8 (byte) representation
byte_tensor
=
reshaped
.
view
(
torch
.
uint8
)
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
)
# Pack 4 uint8 values into one int32
packed
=
(
byte_tensor
[:,
0
].
to
(
torch
.
int32
)
|
(
byte_tensor
[:,
1
].
to
(
torch
.
int32
)
<<
8
)
|
(
byte_tensor
[:,
2
].
to
(
torch
.
int32
)
<<
16
)
|
(
byte_tensor
[:,
3
].
to
(
torch
.
int32
)
<<
24
))
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
packed
.
view
(
fp8_tensor
.
shape
[
0
]
//
4
,
*
fp8_tensor
.
shape
[
1
:]).
contiguous
()
return
output
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
0 → 100644
View file @
b675069d
from
typing
import
Optional
import
torch
import
vllm._custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
from
.marlin_utils
import
marlin_make_workspace
,
marlin_permute_scales
def
is_fp8_marlin_supported
():
capability
=
current_platform
.
get_device_capability
()
return
capability
[
0
]
>=
8
def
apply_fp8_marlin_linear
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_n
:
int
,
size_k
:
int
,
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
size_n
,
)
output
=
ops
.
fp8_marlin_gemm
(
a
=
reshaped_x
,
b_q_weight
=
weight
,
b_scales
=
weight_scale
,
workspace
=
workspace
,
num_bits
=
8
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
size_n
,
size_k
=
size_k
,
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
def
prepare_fp8_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
)
->
None
:
print_warning_once
(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
part_size_n
=
layer
.
output_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
device
=
layer
.
weight
.
device
# WORKSPACE
layer
.
workspace
=
marlin_make_workspace
(
part_size_n
,
device
)
# WEIGHT
# Repack weights to marlin format
marlin_qweight
=
ops
.
gptq_marlin_repack
(
b_q_weight
=
pack_fp8_to_int32
(
layer
.
weight
),
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
),
size_k
=
part_size_k
,
size_n
=
part_size_n
,
num_bits
=
8
)
layer
.
weight
=
torch
.
nn
.
Parameter
(
marlin_qweight
,
requires_grad
=
False
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales
=
layer
.
weight_scale
.
repeat
(
1
,
part_size_n
).
to
(
layer
.
orig_dtype
).
to
(
device
)
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
part_size_k
,
size_n
=
part_size_n
,
group_size
=-
1
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
marlin_scales
,
requires_grad
=
False
)
def
pack_fp8_to_int32
(
fp8_tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Repack FP8 weights to gptq format (packed int32 elements)
"""
assert
fp8_tensor
.
dtype
==
torch
.
float8_e4m3fn
assert
fp8_tensor
.
shape
[
0
]
%
4
==
0
# Reshape to prepare for packing
reshaped
=
fp8_tensor
.
reshape
(
-
1
,
4
,
*
fp8_tensor
.
shape
[
1
:])
# Convert fp8 to uint8 (byte) representation
byte_tensor
=
reshaped
.
view
(
torch
.
uint8
)
# Pack 4 uint8 values into one int32
packed
=
(
byte_tensor
[:,
0
].
to
(
torch
.
int32
)
|
(
byte_tensor
[:,
1
].
to
(
torch
.
int32
)
<<
8
)
|
(
byte_tensor
[:,
2
].
to
(
torch
.
int32
)
<<
16
)
|
(
byte_tensor
[:,
3
].
to
(
torch
.
int32
)
<<
24
))
return
packed
.
view
(
fp8_tensor
.
shape
[
0
]
//
4
,
*
fp8_tensor
.
shape
[
1
:]).
contiguous
()
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
0 → 100644
View file @
b675069d
"""Utility functions used for tests and benchmarks"""
from
typing
import
List
import
numpy
import
torch
from
.marlin_utils
import
GPTQ_MARLIN_TILE
,
marlin_permute_scales
from
.quant_utils
import
get_pack_factor
,
quantize_weights
,
sort_weights
class
MarlinWorkspace
:
def
__init__
(
self
,
out_features
,
min_thread_n
,
max_parallel
):
assert
(
out_features
%
min_thread_n
==
0
),
(
"out_features = {} is undivisible by min_thread_n = {}"
.
format
(
out_features
,
min_thread_n
))
max_workspace_size
=
((
out_features
//
min_thread_n
)
*
max_parallel
)
self
.
scratch
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
"cuda"
)
def
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
,
tile
=
GPTQ_MARLIN_TILE
):
assert
q_w
.
shape
==
(
size_k
,
size_n
)
assert
size_k
%
tile
==
0
,
f
"size_k =
{
size_k
}
, tile =
{
tile
}
"
assert
size_n
%
tile
==
0
,
f
"size_k =
{
size_n
}
, tile =
{
tile
}
"
# Permute weights to 16x64 marlin tiles
q_w
=
q_w
.
reshape
((
size_k
//
tile
,
tile
,
size_n
//
tile
,
tile
))
q_w
=
q_w
.
permute
((
0
,
2
,
1
,
3
))
q_w
=
q_w
.
reshape
((
size_k
//
tile
,
size_n
*
tile
))
q_w
=
q_w
.
reshape
((
-
1
,
perm
.
numel
()))[:,
perm
].
reshape
(
q_w
.
shape
)
return
q_w
def
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
perm
):
# Permute
q_w
=
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
)
# Pack
pack_factor
=
get_pack_factor
(
num_bits
)
orig_device
=
q_w
.
device
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
numpy
.
uint32
)
q_packed
=
numpy
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
dtype
=
numpy
.
uint32
)
for
i
in
range
(
pack_factor
):
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
q_packed
=
torch
.
from_numpy
(
q_packed
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
return
q_packed
def
get_weight_perm
(
num_bits
:
int
):
perm_list
:
List
[
int
]
=
[]
for
i
in
range
(
32
):
perm1
:
List
[
int
]
=
[]
col
=
i
//
4
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col
+
8
*
block
)
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm_list
)
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
perm
=
perm
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
return
perm
def
marlin_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
act_order
:
bool
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
# Quantize (and apply act_order if provided)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w
,
num_bits
,
group_size
,
act_order
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
if
act_order
:
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
# Reformat to marlin
weight_perm
=
get_weight_perm
(
num_bits
)
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
weight_perm
)
marlin_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
)
# Create result
res_list
=
[
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
rand_perm
]
for
i
in
range
(
len
(
res_list
)):
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
vllm/model_executor/layers/quantization/utils/
forma
t_24.py
→
vllm/model_executor/layers/quantization/utils/
marlin_utils_tes
t_24.py
View file @
b675069d
#
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
#
"""Utility functions used for tests and benchmarks"""
import
random
from
typing
import
List
import
numpy
import
torch
from
.marlin_utils_test
import
marlin_weights
from
.quant_utils
import
quantize_weights
# This is PyTorch implementation of main part of reorder_meta()
# function, from tools/util/include/cutlass/util/host_reorder.h file
...
...
@@ -306,3 +311,155 @@ def mask_creator(tensor):
mask
=
w_b
.
scatter_
(
dim
=
1
,
index
=
index
,
value
=
0
).
reshape
(
tensor
.
shape
)
return
mask
def
inject_24
(
w
,
size_k
,
size_n
):
assert
w
.
shape
==
(
size_k
,
size_n
)
mask
=
mask_creator
(
w
.
t
()).
t
().
cuda
().
bool
()
return
(
mask
*
w
).
contiguous
(),
mask
.
contiguous
()
def
check_24
(
w
,
num_rows_to_sample
=
50
,
_verbose
=
False
):
BLOCK_SIZE
=
4
MAX_NON_ZEROS
=
2
w
=
w
.
t
().
contiguous
()
print
(
"check_24: w.shape = {}"
.
format
(
w
.
shape
))
num_rows
,
num_cols
=
w
.
shape
sampled_row_idxs
=
random
.
choices
(
range
(
num_rows
),
k
=
num_rows_to_sample
)
if
_verbose
:
print
(
f
"Sampled row idxs =
{
sampled_row_idxs
}
"
)
total_segments
=
0
non_24_segments
=
0
for
i
in
sampled_row_idxs
:
for
j
in
range
(
0
,
num_cols
-
BLOCK_SIZE
,
BLOCK_SIZE
):
total_segments
+=
1
block
=
w
[
i
,
j
:
j
+
BLOCK_SIZE
]
num_nonzero
=
torch
.
count_nonzero
(
block
)
if
num_nonzero
>
MAX_NON_ZEROS
:
print
(
"i = {} j = {} block = {}"
.
format
(
i
,
j
,
block
))
non_24_segments
+=
1
print
(
f
"
{
non_24_segments
}
/
{
total_segments
}
do not have 2:4 structure."
)
def
compress_quantized_24_weight
(
q_24
,
size_k
,
size_n
,
num_bits
):
assert
q_24
.
shape
==
(
size_k
,
size_n
)
# Remove zp to normalize over 0
max_q_val
=
(
1
<<
num_bits
)
-
1
zp
=
(
max_q_val
+
1
)
//
2
q_24_no_zp
=
q_24
-
zp
# Compress
q_24_no_zp
=
q_24_no_zp
.
t
().
contiguous
()
q_24_no_zp_comp
,
meta
=
sparse_semi_structured_from_dense_cutlass
(
q_24_no_zp
)
q_24_no_zp_comp
=
q_24_no_zp_comp
.
t
().
contiguous
()
# Restore zp
q_24_comp
=
q_24_no_zp_comp
+
zp
# Resize meta to its actual shape (without moving any data)
meta
=
meta
.
resize_
(
meta
.
shape
[
1
]
//
2
,
meta
.
shape
[
0
]
*
2
)
return
q_24_comp
,
meta
def
get_scale_perms_24
():
scale_perm
:
List
[
int
]
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
*
8
+
j
for
j
in
[
0
,
4
,
1
,
5
,
2
,
6
,
3
,
7
]])
scale_perm_single
:
List
[
int
]
=
[]
for
i
in
range
(
8
):
scale_perm_single
.
extend
([
8
*
i
+
j
for
j
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]])
return
scale_perm
,
scale_perm_single
def
get_weight_perm_24
(
num_bits
:
int
):
perm_list
:
List
[
int
]
=
[]
for
i
in
range
(
32
):
perm1
:
List
[
int
]
=
[]
col
=
i
//
4
col_o
=
col
//
2
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col_o
*
256
+
8
*
(
col
%
2
)
+
4
*
block
)
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
1
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm_list
)
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
ValueError
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
perm
=
perm
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
return
perm
def
marlin_permute_scales_24
(
s
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
group_size
:
int
)
->
torch
.
Tensor
:
scale_perm
,
scale_perm_single
=
get_scale_perms_24
()
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
def
marlin_24_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
# Inject 2:4 sparsity
w_24
,
mask_24
=
inject_24
(
w
,
size_k
,
size_n
)
# Quantize
w_24_ref
,
q_w_24
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w_24
,
num_bits
,
group_size
,
act_order
=
False
)
# Compress quantized weight
q_w_24_comp
,
meta
=
compress_quantized_24_weight
(
q_w_24
,
size_k
,
size_n
,
num_bits
)
size_k_comp
=
size_k
//
2
# Reformat to marlin
weight_perm
=
get_weight_perm_24
(
num_bits
)
marlin_24_q_w_comp
=
marlin_weights
(
q_w_24_comp
,
size_k_comp
,
size_n
,
num_bits
,
weight_perm
)
marlin_24_s
=
marlin_permute_scales_24
(
s
,
size_k
,
size_n
,
group_size
)
# Create result
res_list
=
[
w_24_ref
,
marlin_24_q_w_comp
,
meta
,
marlin_24_s
]
for
i
in
range
(
len
(
res_list
)):
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment