OpenDAS / AutoAWQ · Commits · a3db8099

Unverified commit a3db8099, authored Jan 07, 2024 by Casper, committed by GitHub on Jan 07, 2024
GGUF compatible quantization (2, 3, 4 bit / any bit) (#285)
parent 46415f5a

Showing 3 changed files with 83 additions and 5 deletions:
  awq/models/base.py              +23 -4
  awq/quantize/quantizer.py       +12 -1
  examples/awq_to_gguf_quant.py   +48 -0
awq/models/base.py

@@ -83,17 +83,36 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={}, calib_data: Union[str, List[str]] = "pileval",
-                 split="train", text_column="text", duo_scaling=True, modules_to_not_convert=None):
+                 split="train", text_column="text", duo_scaling=True, modules_to_not_convert=None,
+                 export_compatible=False):
         self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

-        quantizer = AwqQuantizer(
+        self.quantizer = AwqQuantizer(
             self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
-            self.quant_config.version, calib_data, split, text_column, duo_scaling,
-            modules_to_not_convert=modules_to_not_convert
+            self.quant_config.version, calib_data, split, text_column, duo_scaling,
+            modules_to_not_convert=modules_to_not_convert,
+            export_compatible=export_compatible
         )
-        quantizer.quantize()
+        self.quantizer.quantize()
         self.is_quantized = True

+    @torch.no_grad()
+    def pack(self):
+        """
+        A utility function for the following scenario. Note that save_quantized will
+        overwrite existing weights if you use the same quant_path.
+
+        model.quantize(
+            tokenizer,
+            quant_config=quant_config,
+            export_compatible=True
+        )
+        model.save_quantized(...)  # produces GGUF/other compat weights
+        model.pack(...)  # makes the model CUDA compat
+        model.save_quantized(...)  # produces CUDA compat weights
+        """
+        self.quantizer.pack()
+
     @staticmethod
     def fuse_layers(model):
         pass
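The pack() docstring above describes a two-stage export workflow. A minimal sketch of that flow, assuming illustrative output directories ('mistral-gguf-compat' and 'mistral-awq-cuda' are placeholders, not paths from the commit):

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    # Illustrative model and config; the point is the export_compatible/pack usage.
    model_path = 'mistralai/Mistral-7B-v0.1'
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Stage 1: run the AWQ search and apply scales, but skip packing the weights.
    model.quantize(tokenizer, quant_config=quant_config, export_compatible=True)
    model.save_quantized('mistral-gguf-compat')  # FP16 weights with AWQ scales applied

    # Stage 2: pack the deferred weights so the model is loadable by AutoAWQ's CUDA kernels.
    model.pack()
    model.save_quantized('mistral-awq-cuda')     # packed quantized weights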
awq/quantize/quantizer.py

@@ -21,7 +21,8 @@ from awq.utils.module import (
 class AwqQuantizer:
     def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
-                 calib_data, split, text_column, duo_scaling, modules_to_not_convert=None) -> None:
+                 calib_data, split, text_column, duo_scaling, modules_to_not_convert=None,
+                 export_compatible=False) -> None:
         self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer

@@ -32,6 +33,7 @@ class AwqQuantizer:
         self.split = split
         self.text_column = text_column
         self.duo_scaling = duo_scaling
+        self.export_compatible = export_compatible
         self.modules_to_not_convert = modules_to_not_convert if modules_to_not_convert is not None else []
         self.modules, self.module_kwargs, self.inps = self.init_quant()

@@ -115,6 +117,15 @@ class AwqQuantizer:
             clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")

             # [STEP 4]: Quantize weights
-            self._apply_quant(self.modules[i], named_linears)
+            if not self.export_compatible:
+                self._apply_quant(self.modules[i], named_linears)
+
             clear_memory()

+    def pack(self):
+        for i in tqdm(range(len(self.modules)), desc="Packing"):
+            named_linears = get_named_linears(self.modules[i])
+            named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)
+            self._apply_quant(self.modules[i], named_linears)
+            clear_memory()
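In short, export_compatible only skips the final weight-packing step: the scale search and clipping still run inside quantize(), and pack() replays the deferred step per module later. A rough standalone sketch of that deferral pattern (not the actual AwqQuantizer implementation; apply_scales and pack_module are placeholder names):

    class DeferredPackingQuantizer:
        """Toy illustration of the export_compatible flag: scaling always runs,
        packing is optional and can be replayed later via pack()."""

        def __init__(self, modules, export_compatible=False):
            self.modules = modules
            self.export_compatible = export_compatible

        def quantize(self):
            for module in self.modules:
                self.apply_scales(module)       # always: scale/clip search and application
                if not self.export_compatible:
                    self.pack_module(module)    # normal path: pack weights immediately

        def pack(self):
            # Deferred path: pack every module after the scaled FP16 model was exported.
            for module in self.modules:
                self.pack_module(module)

        def apply_scales(self, module):
            pass  # placeholder for the real scale/clip logic

        def pack_module(self, module):
            pass  # placeholder for the real weight-packing logic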
examples/awq_to_gguf_quant.py (new file)

import os
import subprocess

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'mistralai/Mistral-7B-v0.1'
quant_path = 'mistral-awq'
llama_cpp_path = '/workspace/llama.cpp'
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 6, "version": "GEMM"}

# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(
    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
# NOTE: We avoid packing weights, so you cannot use this model in AutoAWQ
# after quantizing. The saved model is FP16 but has the AWQ scales applied.
model.quantize(
    tokenizer,
    quant_config=quant_config,
    export_compatible=True
)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')

# GGUF conversion
print('Converting model to GGUF...')
llama_cpp_method = "q4_K_M"
convert_cmd_path = os.path.join(llama_cpp_path, "convert.py")
quantize_cmd_path = os.path.join(llama_cpp_path, "quantize")

if not os.path.exists(llama_cpp_path):
    cmd = (
        f"git clone https://github.com/ggerganov/llama.cpp.git {llama_cpp_path} "
        f"&& cd {llama_cpp_path} && make LLAMA_CUBLAS=1 LLAMA_CUDA_F16=1"
    )
    subprocess.run([cmd], shell=True, check=True)

subprocess.run([
    f"python {convert_cmd_path} {quant_path} --outfile {quant_path}/model.gguf"
], shell=True, check=True)

subprocess.run([
    f"{quantize_cmd_path} {quant_path}/model.gguf {quant_path}/model_{llama_cpp_method}.gguf {llama_cpp_method}"
], shell=True, check=True)
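As a follow-on (not part of the committed example), the resulting GGUF file can be run with the llama.cpp main binary built above; the prompt and token count here are illustrative:

    # Hedged follow-on sketch: run the q4_K_M GGUF produced above with llama.cpp.
    run_cmd_path = os.path.join(llama_cpp_path, "main")
    subprocess.run([
        f"{run_cmd_path} -m {quant_path}/model_{llama_cpp_method}.gguf "
        f"-p 'The meaning of life is' -n 64"
    ], shell=True, check=True)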