Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dd9857f5
Unverified
Commit
dd9857f5
authored
Aug 26, 2024
by
Dipika Sikka
Committed by
GitHub
Aug 26, 2024
Browse files
[Misc] Update `gptq_marlin_24` to use vLLMParameters (#7762)
Co-authored-by:
Michael Goin
<
michael@neuralmagic.com
>
parent
66530409
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
54 deletions
+50
-54
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+1
-1
vllm/model_executor/layers/quantization/gptq_marlin_24.py
vllm/model_executor/layers/quantization/gptq_marlin_24.py
+49
-53
No files found.
vllm/model_executor/layers/linear.py
View file @
dd9857f5
...
...
@@ -23,7 +23,7 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
]
...
...
vllm/model_executor/layers/quantization/gptq_marlin_24.py
View file @
dd9857f5
...
...
@@ -8,7 +8,10 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
...
...
@@ -149,7 +152,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
**
extra_weight_attrs
,
):
del
output_size
# Unused.
weight_loader
=
extra_weight_attrs
[
"weight_loader"
]
if
params_dtype
!=
torch
.
float16
:
raise
ValueError
(
f
"The params dtype must be float16, but got
{
params_dtype
}
"
)
...
...
@@ -187,87 +190,80 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
"Each permutation group must reside on the same gpu"
)
# Quantized 4Bit weights packed into Int32.
qweight
=
Parameter
(
torch
.
empty
(
qweight
=
PackedvLLM
Parameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
tile_size
//
2
,
output_size_per_partition
*
self
.
quant_config
.
tile_size
//
self
.
quant_config
.
pack_factor
,
device
=
"cuda"
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
"marlin_tile_size"
:
self
.
quant_config
.
tile_size
,
},
)
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
marlin_tile_size
=
self
.
quant_config
.
tile_size
,
weight_loader
=
weight_loader
)
# Meta
meta
=
Parameter
(
torch
.
empty
(
meta
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
8
//
2
//
2
,
output_size_per_partition
*
2
,
device
=
"cuda"
,
dtype
=
torch
.
int16
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
meta
,
{
"input_dim"
:
0
,
"packed_dim"
:
1
,
"pack_factor"
:
1
,
"output_dim"
:
1
,
"marlin_tile_size"
:
2
,
},
)
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
1
,
marlin_tile_size
=
2
,
weight_loader
=
weight_loader
)
# Determine if channelwise or not
input_groups
=
(
1
if
self
.
quant_config
.
group_size
==
-
1
else
input_size_per_partition
//
self
.
quant_config
.
group_size
)
scales
=
Parameter
(
weight_scale_args
=
{
"data"
:
torch
.
empty
(
input_groups
,
output_size_per_partition
,
device
=
"cuda"
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
"input_dim"
:
None
if
input_groups
==
1
else
0
,
"output_dim"
:
1
,
},
)
"weight_loader"
:
weight_loader
}
if
input_groups
==
1
:
scales
=
ChannelQuantScaleParameter
(
output_dim
=
1
,
**
weight_scale_args
)
else
:
scales
=
GroupQuantScaleParameter
(
output_dim
=
1
,
input_dim
=
0
,
**
weight_scale_args
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size
=
(
output_size_per_partition
//
self
.
quant_config
.
min_n_threads
)
*
self
.
quant_config
.
max_parallel
workspace
=
Parameter
(
torch
.
zeros
(
max_workspace_size
,
workspace
=
BasevLLMParameter
(
data
=
torch
.
zeros
(
max_workspace_size
,
device
=
"cuda"
,
dtype
=
torch
.
int
),
requires_grad
=
False
)
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"B_24"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"B_meta"
,
meta
)
set_weight_attrs
(
meta
,
extra_weight_attrs
)
layer
.
register_parameter
(
"s"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
register_parameter
(
"workspace"
,
workspace
)
set_weight_attrs
(
workspace
,
extra_weight_attrs
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# required by torch.compile
layer
.
B_24
=
Parameter
(
layer
.
B_24
.
data
,
requires_grad
=
False
)
layer
.
s
=
Parameter
(
layer
.
s
.
data
,
requires_grad
=
False
)
layer
.
B_meta
=
Parameter
(
layer
.
B_meta
.
data
,
requires_grad
=
False
)
layer
.
workspace
=
Parameter
(
layer
.
workspace
.
data
,
requires_grad
=
False
)
def
apply
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment