OpenDAS / AutoAWQ · Commits · 29ee66d9

Commit 29ee66d9 (unverified), authored Feb 03, 2024 by Casper, committed by GitHub on Feb 03, 2024

PEFT compatible GEMM (#324)

parent ebe8fc3f
Showing 2 changed files with 181 additions and 27 deletions:

  awq/modules/linear/gemm.py   +105 −27
  examples/awq_train.py        +76 −0
awq/modules/linear/gemm.py
 import torch
 import torch.nn as nn
+from torch.autograd import Function
 from awq.utils.utils import get_best_device
 from awq.utils.packing_utils import dequantize_gemm
@@ -10,9 +11,94 @@ try:
 except:
     AWQ_INSTALLED = False
 
+
+# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
+class WQLinearMMFunction(Function):
+    @staticmethod
+    # ctx is the first argument to forward
+    def forward(
+        ctx,
+        x,
+        qweight,
+        qzeros,
+        scales,
+        w_bit=4,
+        group_size=128,
+        bias=None,
+        out_features=0,
+    ):
+        # The forward pass can use ctx.
+        ctx.save_for_backward(x, qweight, qzeros, scales, bias)
+        ctx.out_features = out_features
+
+        out_shape = x.shape[:-1] + (out_features,)
+        x = x.to(torch.float16)
+
+        if AWQ_INSTALLED:
+            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
+
+            if FP16_MATMUL_HEURISTIC_CONDITION:
+                out = awq_ext.dequantize_weights_cuda(
+                    qweight, scales, qzeros, 0, 0, 0, False
+                )
+                out = torch.matmul(x, out)
+            else:
+                out = awq_ext.gemm_forward_cuda(
+                    x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8
+                )
+        else:
+            out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
+            out = torch.matmul(x, out)
+
+        out = out + bias if bias is not None else out
+        out = out.reshape(out_shape)
+
+        # always want 3D tensor if tensor is 2D
+        if len(out.shape) == 2:
+            out = out.unsqueeze(0)
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, qweight, qzeros, scales, bias = ctx.saved_tensors
+
+        weights = awq_ext.dequantize_weights_cuda(
+            qweight, scales, qzeros, 1, 0, 0, False
+        )
+
+        if ctx.needs_input_grad[0]:
+            # 2D matrix multiplication, unsqueeze to 3D
+            grad_input = grad_output.squeeze(0).mm(
+                weights.transpose(0, 1)
+            ).unsqueeze(0)
+
+        return grad_input, None, None, None, None, None, None, None
+
+
 class WQLinear_GEMM(nn.Module):
-    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
+    def __init__(
+        self, w_bit, group_size, in_features, out_features, bias, dev, training=False
+    ):
         super().__init__()
 
         if w_bit not in [4]:
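For reference, the class added above follows PyTorch's standard custom torch.autograd.Function pattern: forward() stashes what it needs with ctx.save_for_backward(), and backward() returns one gradient (or None) per forward argument. A minimal, self-contained sketch of that pattern with an ordinary dense weight matrix (hypothetical names, no AWQ packing involved):

import torch
from torch.autograd import Function

class DenseMMFunction(Function):
    # Toy stand-in for WQLinearMMFunction: same structure, but with a plain
    # dense weight matrix instead of packed int4 buffers.
    @staticmethod
    def forward(ctx, x, weight, bias=None):
        ctx.save_for_backward(x, weight, bias)
        out = torch.matmul(x, weight)
        return out + bias if bias is not None else out

    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias = ctx.saved_tensors
        # One return value per forward argument; None means "no gradient".
        grad_input = grad_output.matmul(weight.transpose(-1, -2))
        return grad_input, None, None

x = torch.randn(1, 8, 16, requires_grad=True)
w = torch.randn(16, 32)
DenseMMFunction.apply(x, w).sum().backward()
print(x.grad.shape)  # torch.Size([1, 8, 16])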
@@ -22,6 +108,7 @@ class WQLinear_GEMM(nn.Module):
         self.out_features = out_features
         self.w_bit = w_bit
         self.group_size = group_size if group_size != -1 else in_features
+        self.training = training
 
         # quick sanity check (make sure alignment)
         assert self.in_features % self.group_size == 0
@@ -145,7 +232,6 @@ class WQLinear_GEMM(nn.Module):
         return awq_linear
 
-    @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features,)
@@ -153,37 +239,29 @@ class WQLinear_GEMM(nn.Module):
         if input_dtype != torch.float16:
             x = x.half()
 
-        if AWQ_INSTALLED:
-            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
-
-            if FP16_MATMUL_HEURISTIC_CONDITION:
-                out = awq_ext.dequantize_weights_cuda(
-                    self.qweight,
-                    self.scales,
-                    self.qzeros,
-                    0,
-                    0,
-                    0,
-                    False,
-                )
-                out = torch.matmul(x, out)
-            else:
-                out = awq_ext.gemm_forward_cuda(
-                    x.reshape(-1, x.shape[-1]),
-                    self.qweight,
-                    self.scales,
-                    self.qzeros,
-                    8,
-                )
-        else:
-            out = dequantize_gemm(
-                self.qweight,
-                self.qzeros,
-                self.scales,
-                self.w_bit,
-                self.group_size,
-            )
-            out = torch.matmul(x, out)
+        if self.training:
+            out = WQLinearMMFunction.apply(
+                x,
+                self.qweight,
+                self.qzeros,
+                self.scales,
+                self.w_bit,
+                self.group_size,
+                self.bias,
+                self.out_features,
+            )
+        else:
+            with torch.no_grad():
+                out = WQLinearMMFunction.apply(
+                    x,
+                    self.qweight,
+                    self.qzeros,
+                    self.scales,
+                    self.w_bit,
+                    self.group_size,
+                    self.bias,
+                    self.out_features,
+                )
 
         if input_dtype != torch.float16:
             out = out.to(dtype=input_dtype)
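The net effect of the forward() change: when self.training is set, the layer routes through WQLinearMMFunction.apply so autograd can push gradients back to the fp16 activations, while the packed int4 weights stay frozen (backward() returns None for them); otherwise the old torch.no_grad() behaviour is kept. A rough sketch of that intent (not part of the commit, and assuming a CUDA setup with the awq_ext kernels, since backward() calls awq_ext.dequantize_weights_cuda):

import torch
from awq import AutoAWQForCausalLM
from awq.modules.linear.gemm import WQLinear_GEMM

model = AutoAWQForCausalLM.from_quantized("ybelkada/opt-125m-awq", fuse_layers=False)

# Pick any quantized linear layer and switch it onto the differentiable path.
layer = next(m for m in model.model.modules() if isinstance(m, WQLinear_GEMM))
layer.training = True

x = torch.randn(
    1, 8, layer.in_features,
    dtype=torch.float16, device=layer.qweight.device, requires_grad=True,
)
out = layer(x)                # goes through WQLinearMMFunction.apply
out.float().sum().backward()

print(x.grad.shape)           # gradients reach the activations
print(layer.qweight.grad)     # None: the packed weights are not trained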
examples/awq_train.py (new file, mode 0 → 100644)
import datasets
from awq import AutoAWQForCausalLM
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import get_peft_model, LoraConfig, TaskType


def prepare_split(tokenizer):
    data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train")
    prompt_template = "<s>[INST] {system} {prompt} [/INST] {output}</s>"

    def format_prompt(x):
        return prompt_template.format(
            system="",
            prompt=x["instruction"],
            output=x["output"]
        )

    data = data.map(
        lambda x: {"text": format_prompt(x)},
    ).select_columns(["text"])
    data = data.map(lambda x: tokenizer(x["text"]), batched=True)

    return data


model_path = "ybelkada/opt-125m-awq"

# Load model
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

# Prepare data
data_train = prepare_split(tokenizer)

# Config Lora
lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    lora_dropout=0.5,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False
)

model = get_peft_model(model.model, lora_config)
model.print_trainable_parameters()

training_arguments = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,
    optim="adamw_torch",
    num_train_epochs=1,
    learning_rate=1e-4,
    # fp16=True,
    evaluation_strategy="no",
    save_strategy="epoch",
    save_steps=100,
    logging_steps=50,
    eval_steps=None,
    load_best_model_at_end=False
)

trainer = Trainer(
    model=model,
    train_dataset=data_train,
    args=training_arguments,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()
trainer.save_model("output")
\ No newline at end of file
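A note on the example: model.model (the underlying transformers model) is what gets wrapped by get_peft_model, so only the LoRA adapter weights are trainable and trainer.save_model("output") writes just that adapter. One way to reload it for inference, sketched under the assumption that the adapter was saved to ./output as above (not part of the commit):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from peft import PeftModel

base = AutoAWQForCausalLM.from_quantized("ybelkada/opt-125m-awq", fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained("ybelkada/opt-125m-awq")

# Attach the adapter produced by trainer.save_model("output")
model = PeftModel.from_pretrained(base.model, "output")
model.eval()

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(base.model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))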