OpenDAS / AutoAWQ · Commit db8ba322

Add safetensors support

Authored Sep 13, 2023 by Casper Hansen
Parent: 97d38e29

Showing 2 changed files with 22 additions and 20 deletions (+22 -20)
awq/models/auto.py    +4 -3
awq/models/base.py    +18 -17
awq/models/auto.py

@@ -35,12 +35,13 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
+    def from_quantized(self, quant_path, quant_filename='pytorch_model.bin', max_new_tokens=None,
                        device='balanced', trust_remote_code=True, fuse_layers=True,
-                       batch_size=1) -> BaseAWQForCausalLM:
+                       batch_size=1, use_safetensors=False) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
             quant_path, model_type, quant_filename, max_new_tokens, device,
-            trust_remote_code=trust_remote_code, fuse_layers=fuse_layers
+            trust_remote_code=trust_remote_code, fuse_layers=fuse_layers,
+            safetensors=use_safetensors
         )
\ No newline at end of file
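For orientation, here is a minimal usage sketch of the updated entry point (not part of the commit; the repo id and filename below are placeholders). The new use_safetensors flag is simply forwarded to the per-model loader as safetensors=:

# Hypothetical usage sketch; repo id and filename are placeholders.
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "some-org/some-model-awq",            # quantized checkpoint dir or hub repo (placeholder)
    quant_filename="model.safetensors",   # name of the safetensors weights file
    use_safetensors=True,                 # forwarded as safetensors=True to the model class
    fuse_layers=True,
)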
awq/models/base.py

@@ -11,6 +11,7 @@ from safetensors.torch import save_file
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
+from awq.utils.utils import simple_dispatch_model
 from awq.utils.calib_data import get_calib_dataset
 from transformers.modeling_utils import shard_checkpoint
 from awq.quantize.quantizer import pseudo_quantize_tensor
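The hunk context above imports save_file from safetensors.torch; for readers unfamiliar with the format, here is a small self-contained round trip it enables (illustrative only, not code from the repository):

# Illustrative safetensors round trip (not part of the commit).
import torch
from safetensors.torch import save_file, load_file

state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

# Tensors are written to a flat, mmap-friendly file with no pickled code.
save_file(state_dict, "model.safetensors")

# Loading returns a plain dict of tensors; the target device can be chosen here.
loaded = load_file("model.safetensors", device="cpu")
assert torch.equal(state_dict["weight"], loaded["weight"])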
@@ -18,7 +19,7 @@ from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
+from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
 
 class BaseAWQForCausalLM(nn.Module):
@@ -217,7 +218,7 @@ class BaseAWQForCausalLM(nn.Module):
         return awq_results
 
     def save_quantized(self, save_dir, use_safetensors=False, shard_size="10GB"):
-        def _save_files(save_dir, model_name, search_result=None):
+        def _save_files(save_dir, model_name='', search_result=None):
             class EmptyModule(nn.Module):
                 def __init__(self): super(EmptyModule, self).__init__()
                 def forward(self, x): return x
@@ -232,7 +233,7 @@ class BaseAWQForCausalLM(nn.Module):
                 torch.save(search_result, f'{save_dir}/{model_name}')
             else:
                 # model_name has no extension, add it when saving state_dict
-                model_name += '.safetensors' if use_safetensors else '.bin'
+                model_name = 'model.safetensors' if use_safetensors else 'pytorch_model.bin'
 
                 # shard checkpoint into chunks (10GB default)
                 shards, index = shard_checkpoint(
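The switch to fixed Hugging Face-style filenames matters because shard_checkpoint derives shard names from the weights filename. Below is a simplified sketch of what this save branch does, under the assumption that the rest of the diff stays as shown (the helper name save_sharded is invented for illustration):

# Simplified sketch of the save branch; save_sharded is a hypothetical helper.
import os
import torch
from safetensors.torch import save_file
from transformers.modeling_utils import shard_checkpoint

def save_sharded(state_dict, save_dir, use_safetensors=False, shard_size="10GB"):
    model_name = 'model.safetensors' if use_safetensors else 'pytorch_model.bin'
    os.makedirs(save_dir, exist_ok=True)

    # Split the state dict into <= shard_size chunks; shard filenames are
    # derived from weights_name, so standard loaders can find them.
    shards, index = shard_checkpoint(
        state_dict, max_shard_size=shard_size, weights_name=model_name
    )
    for shard_file, shard in shards.items():
        if use_safetensors:
            save_file(shard, os.path.join(save_dir, shard_file))
        else:
            torch.save(shard, os.path.join(save_dir, shard_file))
    return index  # None when everything fits into a single shard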
@@ -262,8 +263,7 @@ class BaseAWQForCausalLM(nn.Module):
         # Save model
         if self.search_result is None or self.is_quantized:
-            model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}'
-            _save_files(save_dir, model_name, search_result=None)
+            _save_files(save_dir, '', search_result=None)
         else:
             model_name = 'awq_model_search_result.pt'
             _save_files(save_dir, model_name, self.search_result)
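An end-to-end quantize-and-save sketch with the new flag (paths and quant_config values are placeholders; the quantization call itself is unchanged by this commit):

# Hypothetical end-to-end sketch: quantize, then save as safetensors.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "some-org/some-model"  # placeholder
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model.quantize(tokenizer, quant_config=quant_config)

# With use_safetensors=True the weights are written as model.safetensors
# (sharded if needed) instead of pytorch_model.bin.
model.save_quantized("some-model-awq", use_safetensors=True)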
@@ -284,9 +284,10 @@ class BaseAWQForCausalLM(nn.Module):
     )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
-                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                       safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
+    def from_quantized(self, model_path, model_type, model_filename='pytorch_model.bin',
+                       max_new_tokens=None, device='balanced', torch_dtype=torch.float16,
+                       trust_remote_code=True, safetensors=False, is_quantized=True,
+                       fuse_layers=False, version='GEMM'):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
@@ -297,8 +298,7 @@ class BaseAWQForCausalLM(nn.Module):
             model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
 
-        # TODO: Better naming, model_filename becomes a directory
-        model_filename = model_path + f'/{model_filename}'
+        model_weights_path = model_path + f'/{model_filename}'
 
         # [STEP 2] Load config and set sequence length
         # TODO: Create BaseAWQConfig class
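For reference, the download step resolves a hub repo id to a local snapshot directory and skips non-PyTorch weight formats; a standalone sketch with placeholder names:

# Sketch of STEP 1; repo id and filename are placeholders.
import os
from huggingface_hub import snapshot_download

model_path = "some-org/some-model-awq"
if not os.path.isdir(model_path):
    # Skip Flax (*.msgpack) and TensorFlow (*.h5) weights when downloading.
    model_path = snapshot_download(model_path, ignore_patterns=["*msgpack*", "*h5*"])

# The weights file now lives inside the snapshot directory.
model_weights_path = model_path + "/model.safetensors"
print(model_weights_path)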
@@ -341,13 +341,14 @@ class BaseAWQForCausalLM(nn.Module):
         # Load model weights
         if is_quantized:
-            model = load_checkpoint_and_dispatch(
+            load_checkpoint_in_model(
                 model,
-                model_filename,
-                device_map=device_map,
-                no_split_module_classes=[self.layer_type]
+                checkpoint=model_path if safetensors else model_weights_path,
+                device_map=device_map
             )
+            model = simple_dispatch_model(model, device_map)
 
             if fuse_layers:
                 self.fuse_layers(model, quant_config)
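The move from load_checkpoint_and_dispatch to load_checkpoint_in_model splits weight loading and device placement into two steps, which is what lets the checkpoint argument be either a single file or a directory of safetensors shards. A generic accelerate sketch of that pattern (using a stock transformers model and accelerate's dispatch_model in place of AutoAWQ's simple_dispatch_model helper; repo id and checkpoint path are placeholders):

# Generic two-step load/dispatch sketch with accelerate (illustrative only).
from accelerate import (
    dispatch_model,
    infer_auto_device_map,
    init_empty_weights,
    load_checkpoint_in_model,
)
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("some-org/some-model")  # placeholder
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)        # no weights allocated yet

device_map = infer_auto_device_map(model)

# Step 1: materialize weights from a file or a directory of shards;
# a directory lets accelerate pick up sharded *.safetensors checkpoints.
load_checkpoint_in_model(model, checkpoint="path/to/checkpoint", device_map=device_map)

# Step 2: move modules to their assigned devices.
model = dispatch_model(model, device_map=device_map)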
@@ -357,7 +358,7 @@ class BaseAWQForCausalLM(nn.Module):
             # Load model weights
             model = AutoModelForCausalLM.from_pretrained(
-                model_filename,
+                model_weights_path,
                 device_map=device_map,
                 trust_remote_code=trust_remote_code,
                 offload_folder="offload",