Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
90bf52c7
"tests/test_config.py" did not exist on "a2090375ca8e54e7494d73ba31ce91c50659e556"
Commit
90bf52c7
authored
Sep 20, 2023
by
Casper Hansen
Browse files
Remove cosine loss
parent
9848b6a4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
9 deletions
+5
-9
awq/models/base.py
awq/models/base.py
+2
-2
awq/quantize/quantizer.py
awq/quantize/quantizer.py
+3
-7
No files found.
awq/models/base.py
View file @
90bf52c7
...
...
@@ -38,13 +38,13 @@ class BaseAWQForCausalLM(nn.Module):
@
torch
.
no_grad
()
def
quantize
(
self
,
tokenizer
=
None
,
quant_config
=
{},
calib_data
:
Union
[
str
,
List
[
str
]]
=
"pileval"
,
split
=
"train"
,
text_column
=
"text"
,
loss_objective
=
'mse'
):
split
=
"train"
,
text_column
=
"text"
):
self
.
quant_config
=
quant_config
quant_config
[
"version"
]
=
"GEMM"
if
'version'
not
in
quant_config
.
keys
()
else
quant_config
[
"version"
]
quantizer
=
AwqQuantizer
(
self
,
self
.
model
,
tokenizer
,
quant_config
[
"w_bit"
],
quant_config
[
"q_group_size"
],
quant_config
[
"version"
],
calib_data
,
split
,
text_column
,
loss_objective
quant_config
[
"version"
],
calib_data
,
split
,
text_column
)
quantizer
.
quantize
()
self
.
is_quantized
=
True
...
...
awq/quantize/quantizer.py
View file @
90bf52c7
...
...
@@ -13,7 +13,7 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,
class
AwqQuantizer
:
def
__init__
(
self
,
awq_model
,
model
,
tokenizer
,
w_bit
,
group_size
,
version
,
calib_data
,
split
,
text_column
,
loss_objective
=
'mse'
)
->
None
:
calib_data
,
split
,
text_column
)
->
None
:
self
.
awq_model
=
awq_model
self
.
model
=
model
self
.
tokenizer
=
tokenizer
...
...
@@ -24,7 +24,6 @@ class AwqQuantizer:
self
.
split
=
split
self
.
text_column
=
text_column
self
.
modules
,
self
.
module_kwargs
,
self
.
inps
=
self
.
init_quant
()
self
.
loss_objective
=
loss_objective
def
pseudo_quantize_tensor
(
self
,
w
:
torch
.
Tensor
,
get_scale_zp
=
False
):
org_w_shape
=
w
.
shape
...
...
@@ -191,12 +190,9 @@ class AwqQuantizer:
if
isinstance
(
int_w_output
,
tuple
):
int_w_output
=
int_w_output
[
0
]
if
self
.
loss_objective
==
'mse'
:
#
(L2 norm)
# compute mean squared error
(L2 norm)
loss
=
(
fp16_output
-
int_w_output
).
float
().
pow
(
2
).
mean
().
item
()
# NOTE: float prevents overflow
elif
self
.
loss_objective
==
'cosine'
:
loss
=
-
nn
.
functional
.
cosine_similarity
(
fp16_output
,
int_w_output
).
mean
().
item
()
history
.
append
(
loss
)
if
loss
<
best_error
:
best_error
=
loss
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment