Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
1712ce21
"torchvision/vscode:/vscode.git/clone" did not exist on "7f42437989b5ba564a2fcd3fb8f60eba5f771bdf"
Commit
1712ce21
authored
Sep 05, 2023
by
Erwan BOEHM
Browse files
allow user to use custom calibration data for quantization
parent
abdc726c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
5 deletions
+13
-5
awq/models/base.py
awq/models/base.py
+3
-2
awq/utils/calib_data.py
awq/utils/calib_data.py
+10
-3
No files found.
awq/models/base.py
View file @
1712ce21
import
os
import
os
import
gc
import
gc
import
json
import
json
from
typing
import
List
,
Union
import
torch
import
torch
import
functools
import
functools
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -39,7 +40,7 @@ class BaseAWQForCausalLM(nn.Module):
...
@@ -39,7 +40,7 @@ class BaseAWQForCausalLM(nn.Module):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
quantize
(
self
,
tokenizer
=
None
,
quant_config
=
{},
n_samples
=
128
,
seqlen
=
512
,
def
quantize
(
self
,
tokenizer
=
None
,
quant_config
=
{},
n_samples
=
128
,
seqlen
=
512
,
auto_scale
=
True
,
mse_range
=
True
,
run_search
=
True
,
run_quant
=
True
,
auto_scale
=
True
,
mse_range
=
True
,
run_search
=
True
,
run_quant
=
True
,
calib_data
=
"pileval"
):
calib_data
:
Union
[
str
,
List
[
str
]]
=
"pileval"
):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
if
run_search
:
if
run_search
:
...
@@ -95,7 +96,7 @@ class BaseAWQForCausalLM(nn.Module):
...
@@ -95,7 +96,7 @@ class BaseAWQForCausalLM(nn.Module):
gc
.
collect
()
gc
.
collect
()
def
_awq_search
(
self
,
tokenizer
,
quant_config
,
n_samples
=
128
,
seqlen
=
512
,
def
_awq_search
(
self
,
tokenizer
,
quant_config
,
n_samples
=
128
,
seqlen
=
512
,
auto_scale
=
True
,
mse_range
=
True
,
calib_data
=
"pileval"
):
auto_scale
=
True
,
mse_range
=
True
,
calib_data
:
Union
[
str
,
List
[
str
]]
=
"pileval"
):
layers
=
self
.
get_model_layers
(
self
.
model
)
layers
=
self
.
get_model_layers
(
self
.
model
)
samples
=
get_calib_dataset
(
samples
=
get_calib_dataset
(
...
...
awq/utils/calib_data.py
View file @
1712ce21
from
typing
import
List
,
Union
import
torch
import
torch
import
logging
import
logging
from
datasets
import
load_dataset
from
datasets
import
load_dataset
def
get_calib_dataset
(
data
=
"pileval"
,
tokenizer
=
None
,
n_samples
=
512
,
block_size
=
512
):
def
get_calib_dataset
(
data
:
Union
[
str
,
List
[
str
]]
=
"pileval"
,
tokenizer
=
None
,
n_samples
=
512
,
block_size
=
512
):
if
isinstance
(
data
,
str
):
if
data
==
"pileval"
:
if
data
==
"pileval"
:
dataset
=
load_dataset
(
"mit-han-lab/pile-val-backup"
,
split
=
"validation"
)
dataset
=
load_dataset
(
"mit-han-lab/pile-val-backup"
,
split
=
"validation"
)
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
elif
isinstance
(
data
,
list
):
dataset
=
[{
"text"
:
text
}
for
text
in
data
]
else
:
raise
NotImplementedError
dataset
=
dataset
.
shuffle
(
seed
=
42
)
dataset
=
dataset
.
shuffle
(
seed
=
42
)
samples
=
[]
samples
=
[]
n_run
=
0
n_run
=
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment