Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
babf52da
Unverified
Commit
babf52da
authored
Jul 13, 2024
by
Robert Shaw
Committed by
GitHub
Jul 13, 2024
Browse files
[ Misc ] More Cleanup of Marlin (#6359)
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
parent
9da4aad4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
44 additions
and
48 deletions
+44
-48
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
+1
-1
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+31
-47
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+12
-0
No files found.
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
View file @
babf52da
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
# We use this for fp8, which HF does not support.
# We use this for fp8, which HF does not support.
#
#
# Make sure you have lm-eval-harness installed:
# Make sure you have lm-eval-harness installed:
# pip install lm-eval==0.4.
2
# pip install lm-eval==0.4.
3
usage
()
{
usage
()
{
echo
``
echo
``
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
babf52da
...
@@ -10,8 +10,9 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
...
@@ -10,8 +10,9 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_marlin_supported
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
apply_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
marlin_permute_scales
,
marlin_sort_g_idx
,
replace_tensor
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
...
@@ -145,6 +146,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -145,6 +146,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
)
->
None
:
)
->
None
:
del
output_size
del
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
is_row_parallel
=
input_size
!=
input_size_per_partition
# Normalize group_size
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
if
self
.
quant_config
.
group_size
!=
-
1
:
...
@@ -158,32 +160,19 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -158,32 +160,19 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
input_size
=
input_size
,
input_size
=
input_size
,
group_size
=
group_size
)
group_size
=
group_size
)
# Detect sharding of scales/zp
# Determine sharding
if
marlin_repeat_scales_on_all_ranks
(
self
.
quant_config
.
desc_act
,
# By default, no sharding over "input dim"
self
.
quant_config
.
group_size
,
scales_and_zp_size
=
input_size
//
group_size
is_row_parallel
):
scales_and_zp_input_dim
=
None
# By setting scale_dim == None, weight_loader will
# repeat the scales on each GPU in TP>1 case.
if
self
.
quant_config
.
desc_act
:
scales_and_zp_input_dim
=
None
# Act-order case
scales_and_zp_size
=
input_size
//
group_size
assert
self
.
quant_config
.
group_size
!=
-
1
is_k_full
=
input_size_per_partition
==
input_size
else
:
else
:
# No act-order case
# By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
# K is always full due to full alignment with
scales_and_zp_input_dim
=
0
# group-size and shard of scales/zp
scales_and_zp_size
=
input_size_per_partition
//
group_size
is_k_full
=
True
# If this is a row-parallel case, then shard scales/zp
if
(
input_size
!=
input_size_per_partition
and
self
.
quant_config
.
group_size
!=
-
1
):
scales_and_zp_size
=
input_size_per_partition
//
group_size
scales_and_zp_input_dim
=
0
# Init buffers
# Quantized weights
# Quantized weights
qweight
=
Parameter
(
qweight
=
Parameter
(
...
@@ -268,13 +257,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -268,13 +257,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
input_size
=
input_size
layer
.
input_size
=
input_size
layer
.
is_k_full
=
is_k_full
layer
.
is_k_full
=
marlin_is_k_full
(
self
.
quant_config
.
desc_act
,
is_row_parallel
)
# Checkpoints are serialized in AutoGPTQ format, which is different from the
# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
# Here, we handle the repacking, including the activation reordering case.
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
qweight
.
device
device
=
layer
.
qweight
.
device
# Allocate marlin workspace
# Allocate marlin workspace
layer
.
workspace
=
marlin_make_workspace
(
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
layer
.
output_size_per_partition
,
device
)
...
@@ -312,22 +303,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -312,22 +303,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
return
apply_marlin_linear
(
out_shape
=
x
.
shape
[:
-
1
]
+
(
layer
.
output_size_per_partition
,
)
input
=
x
,
weight
=
layer
.
qweight
,
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
weight_scale
=
layer
.
scales
,
layer
.
qweight
,
g_idx
=
layer
.
g_idx
,
layer
.
scales
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
g_idx
=
layer
.
g_idx
,
workspace
=
layer
.
workspace
,
perm
=
layer
.
g_idx_sort_indices
,
num_bits
=
self
.
quant_config
.
weight_bits
,
workspace
=
layer
.
workspace
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
weight_bits
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
size_m
=
reshaped_x
.
shape
[
0
],
is_k_full
=
layer
.
is_k_full
,
size_n
=
layer
.
output_size_per_partition
,
bias
=
bias
)
size_k
=
layer
.
input_size_per_partition
,
is_k_full
=
layer
.
is_k_full
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
babf52da
...
@@ -91,6 +91,18 @@ def marlin_make_workspace(output_size_per_partition: int,
...
@@ -91,6 +91,18 @@ def marlin_make_workspace(output_size_per_partition: int,
requires_grad
=
False
)
requires_grad
=
False
)
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    """Return whether this rank should treat its K (input) dimension as full.

    Args:
        act_order: True when activation reordering (``desc_act``) is enabled.
        is_row_parallel: True when the layer's input dimension is sharded
            across ranks (``input_size != input_size_per_partition``).

    Returns:
        ``False`` only when act-order is enabled *and* the layer is
        row-parallel; ``True`` in every other case.
    """
    # Without act-order K is always considered full; with act-order it is
    # full only when the input dimension is not sharded.  The original
    # expression `(not a) or (a and not r)` reduces to `not (a and r)`.
    return not (act_order and is_row_parallel)
def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                      is_row_parallel: bool) -> bool:
    """Return whether scales must be repeated on every rank rather than sharded.

    Args:
        act_order: True when activation reordering (``desc_act``) is enabled.
        group_size: Quantization group size; ``-1`` denotes channelwise
            quantization.
        is_row_parallel: True when the layer's input dimension is sharded
            across ranks.

    Returns:
        True when act-order is enabled, or when quantization is channelwise
        and the layer is row-parallel.
    """
    # Need to repeat scales on every rank if act_ordering or
    # channelwise and RowParallelLinear
    is_channelwise = group_size == -1
    return act_order or (is_channelwise and is_row_parallel)
def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    """Create an empty, non-trainable ``g_idx`` placeholder on ``device``.

    Args:
        device: Device on which the placeholder tensor is allocated.

    Returns:
        A ``torch.nn.Parameter`` wrapping a zero-element ``torch.int`` tensor
        with ``requires_grad=False``.
    """
    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                              requires_grad=False)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment