OpenDAS / AutoAWQ · Commits

Commit d1112e1c (unverified)
Authored Nov 28, 2023 by Casper; committed by GitHub, Nov 28, 2023
Parent: 63d2aaec

New scaling to improve perplexity (#216)

Showing 4 changed files with 87 additions and 19 deletions:

    awq/models/base.py          +2   -2
    awq/quantize/quantizer.py   +6   -2
    awq/utils/eval_utils.py     +49  -0
    examples/eval.py            +30  -15
awq/models/base.py  (+2 -2, view file @ d1112e1c)

@@ -49,12 +49,12 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={},
                        calib_data: Union[str, List[str]]="pileval",
-                       split="train", text_column="text"):
+                       split="train", text_column="text", duo_scaling=True):
         self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)
 
         quantizer = AwqQuantizer(
             self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
-            self.quant_config.version, calib_data, split, text_column
+            self.quant_config.version, calib_data, split, text_column, duo_scaling
         )
         quantizer.quantize()
         self.is_quantized = True
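Since duo_scaling defaults to True, existing callers pick up the new scaling without code changes. A minimal usage sketch against the updated signature; the model path and quant_config values below are illustrative assumptions, not part of this commit:

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "facebook/opt-125m"  # illustrative model choice
    quant_config = {"w_bit": 4, "q_group_size": 128, "version": "GEMM"}  # assumed settings

    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # duo_scaling=False falls back to the previous activation-only scaling
    model.quantize(tokenizer, quant_config=quant_config, duo_scaling=True)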
awq/quantize/quantizer.py  (+6 -2, view file @ d1112e1c)

@@ -14,7 +14,7 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,
 class AwqQuantizer:
     def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
-                       calib_data, split, text_column) -> None:
+                       calib_data, split, text_column, duo_scaling) -> None:
         self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer
@@ -24,6 +24,7 @@ class AwqQuantizer:
         self.calib_data = calib_data
         self.split = split
         self.text_column = text_column
+        self.duo_scaling = duo_scaling
         self.modules, self.module_kwargs, self.inps = self.init_quant()
 
     def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
@@ -197,7 +198,10 @@ class AwqQuantizer:
             ratio = ratio / n_grid
 
             # NOTE: s^-1 * x is fused here, according to paper
-            scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
+            if self.duo_scaling:
+                scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
+            else:
+                scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
             scales = scales / (scales.max() * scales.min()).sqrt()
             scales_view = scales.view(1, -1).to(device)
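This hunk is the substance of the commit: with duo scaling, the per-channel scales grow with activation magnitude and shrink with weight magnitude, rather than tracking activations alone, while the grid search over ratio is unchanged. A standalone sketch of the two branches; the helper name and toy tensors are assumptions for illustration, not part of the commit:

    import torch

    def compute_scales(x_max, w_max, ratio, duo_scaling=True):
        if duo_scaling:
            # balance activation magnitude against weight magnitude
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
        else:
            # previous behavior: activation magnitude only
            scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
        # normalize so the geometric mean of the extreme scales is 1
        return scales / (scales.max() * scales.min()).sqrt()

    x_max = torch.tensor([0.5, 2.0, 8.0])   # per-channel activation maxima (toy values)
    w_max = torch.tensor([1.0, 0.25, 0.5])  # per-channel weight maxima (toy values)
    for grid_point in range(3):
        ratio = grid_point / 2              # the quantizer sweeps ratio over a grid
        print(ratio, compute_scales(x_max, w_max, ratio))

At ratio = 0 the duo-scaling branch reduces to 1 / w_max and at ratio = 1 to x_max, so the grid search interpolates between weight-driven and activation-driven scaling.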
awq/utils/eval_utils.py  (new file, mode 0 → 100644; +49 -0, view file @ d1112e1c)

import torch
import torch.nn as nn
from tqdm import tqdm
from datasets import load_dataset


def evaluate_perplexity(model, tokenizer):
    def _perplexity(nlls, n_samples, seqlen):
        return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen))

    # load and prepare dataset
    data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    data = tokenizer("\n\n".join(data['text']), return_tensors='pt')
    data = data.input_ids.to(model.device)

    seqlen = 2048
    model = model.eval()
    n_samples = data.numel() // seqlen

    nlls = []

    with tqdm(range(n_samples), desc="Perplexity -") as progress_bar:
        for i in progress_bar:
            start_index = (i * seqlen)
            end_index = ((i + 1) * seqlen)
            batch = data[:, start_index:end_index].to(model.device)
            with torch.no_grad():
                logits = model(batch).logits
            shift_logits = logits[:, :-1, :].contiguous().float()
            shift_labels = data[:, start_index:end_index][:, 1:]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
            neg_log_likelihood = loss.float() * seqlen
            nlls.append(neg_log_likelihood)

            curr_ppl = _perplexity(nlls, i + 1, seqlen)
            progress_bar.set_description(f"Perplexity {curr_ppl:.3f}")

    ppl = _perplexity(nlls, n_samples, seqlen)

    return ppl.item()


if __name__ == '__main__':
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = 'mistralai/Mistral-7B-Instruct-v0.1'
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    evaluate_perplexity(model, tokenizer)
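Note on the arithmetic: each window's loss is the mean cross-entropy per predicted token, so multiplying it by seqlen and later dividing the summed NLLs by n_samples * seqlen means _perplexity returns exp of the average per-token loss over all non-overlapping 2048-token windows. The __main__ block evaluates an FP16 model; to evaluate an AWQ-quantized checkpoint instead, pass the inner Transformers module, as examples/eval.py below does. A sketch with hypothetical paths:

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer
    from awq.utils.eval_utils import evaluate_perplexity

    quant_path = "mistral-7b-awq"  # hypothetical quantized checkpoint directory
    model = AutoAWQForCausalLM.from_quantized(quant_path, "model.safetensors", fuse_layers=False)
    tokenizer = AutoTokenizer.from_pretrained(quant_path)

    # .model is the wrapped transformers model; evaluate_perplexity
    # uses its .device attribute and logits output directly
    evaluate_perplexity(model.model, tokenizer)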
examples/eval.py  (+30 -15, view file @ d1112e1c)

@@ -3,26 +3,35 @@ from lm_eval import evaluator
 from awq import AutoAWQForCausalLM
 from transformers import AutoTokenizer
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
+from awq.utils.eval_utils import evaluate_perplexity
 
-def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot, task_use_pretrained):
+def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot,
+             task_use_pretrained, pretrained_safetensors):
     """
     Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
     """
     # Load model
     if task_use_pretrained:
-        model = AutoAWQForCausalLM.from_pretrained(model_path)
+        model = AutoAWQForCausalLM.from_pretrained(model_path, safetensors=pretrained_safetensors)
     else:
         model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=False)
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     # Load adapter
-    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)
-
-    # Evaluate perplexity of quantized model
-    results = evaluator.simple_evaluate(
-        model=lm_eval_model,
-        tasks=tasks.split(','),
-        batch_size=task_batch_size,
-        no_cache=True,
-        num_fewshot=task_n_shot,
+    tasks = tasks.split(',')
+    if len(tasks) == 1 and tasks[0] == 'wikitext':
+        evaluate_perplexity(model.model, tokenizer)
+    else:
+        lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)
+
+        # Evaluate perplexity of quantized model
+        results = evaluator.simple_evaluate(
+            model=lm_eval_model,
+            tasks=tasks,
+            batch_size=task_batch_size,
+            no_cache=True,
+            num_fewshot=task_n_shot,
@@ -45,6 +54,8 @@ if __name__ == '__main__':
     parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
     parser.add_argument("--use_pretrained", default=False, action='store_true',
                         help="Pass '--use_pretrained' to use a pretrained model running FP16")
+    parser.add_argument("--pretrained_safetensors", default=False, action='store_true',
+                        help="Load safetensors for FP16 model")
     parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. '
                         'Separate tasks by comma for multiple tasks.'
                         'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md')
@@ -52,5 +63,9 @@ if __name__ == '__main__':
     parser.add_argument('--n_shot', type=int, default=0)
     args = parser.parse_args()
 
-    run_eval(args.model_path, args.quant_file, args.device,
-             args.tasks, args.batch_size, args.n_shot, args.use_pretrained)
\ No newline at end of file
+    run_eval(args.model_path, args.quant_file, args.device,
+             args.tasks, args.batch_size, args.n_shot, args.use_pretrained,
+             args.pretrained_safetensors)
\ No newline at end of file
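A plausible invocation of the updated script; the --model_path, --quant_file, and --batch_size flags are inferred from the args.* references rather than shown in this diff:

    # FP16 baseline perplexity on wikitext, loading safetensors weights
    python examples/eval.py --model_path mistralai/Mistral-7B-Instruct-v0.1 \
        --use_pretrained --pretrained_safetensors --tasks wikitext

    # quantized checkpoint on the same task (paths hypothetical)
    python examples/eval.py --model_path <quantized-dir> --quant_file <quant-file> --tasks wikitext

When --tasks is exactly wikitext, the script now runs the fast in-repo evaluate_perplexity path; any other task list still goes through LMEvalAdaptor and the EleutherAI harness.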