OpenDAS / AutoAWQ · Commits · 471f811b

Increase robustness of model loading
Authored Aug 18, 2023 by Casper Hansen
Parent: 916fdf97

1 changed file with 33 additions and 7 deletions: awq/models/base.py (+33, -7)

awq/models/base.py
@@ -2,18 +2,19 @@ import os
 import gc
 import torch
 import functools
 import accelerate
 import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
 from huggingface_hub import snapshot_download
 from awq.utils.calib_data import get_calib_dataset
-from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
 from awq.quantize.quantizer import pseudo_quantize_tensor
 from awq.quantize.qmodule import WQLinear, ScaledActivation
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name

 class BaseAWQForCausalLM:
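The import change swaps the accelerate line for one that also pulls in infer_auto_device_map, which the new fallback path below relies on. A minimal sketch of how that API is typically used; the model id "facebook/opt-125m" and layer class "OPTDecoderLayer" are placeholders standing in for the commit's self.layer_type, not values from this commit:

import torch
from accelerate import init_empty_weights, infer_auto_device_map
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model skeleton on the meta device; no weight memory is allocated.
config = AutoConfig.from_pretrained("facebook/opt-125m")  # placeholder model id
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Plan layer placement before any weights are loaded. no_split_module_classes
# keeps each transformer block whole on one device, as [self.layer_type] does
# in the commit.
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=["OPTDecoderLayer"],  # placeholder layer class
    dtype=torch.float16,
)
print(device_map)  # maps module names to GPU indices, "cpu", or "disk"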
@@ -215,10 +216,17 @@ class BaseAWQForCausalLM:
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename, w_bit=4, q_config={},
-                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True, is_quantized=True):
+                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True, safetensors=False, is_quantized=True):
         # Download model if path is not a directory
         if not os.path.isdir(model_path):
-            model_path = snapshot_download(model_path)
+            ignore_patterns = ["*msgpack*", "*h5*"]
+            if safetensors:
+                ignore_patterns.extend(["*.pt", "*.bin"])
+            else:
+                ignore_patterns.append("*safetensors*")
+
+            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

         # TODO: Better naming, model_filename becomes a directory
         model_filename = model_path + f'/{model_filename}'
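The new ignore_patterns logic makes snapshot_download fetch only the weight format that will actually be loaded: Flax (*msgpack*) and TensorFlow (*h5*) checkpoints are always skipped, and the safetensors flag chooses between .safetensors and PyTorch .pt/.bin files. The same filter as a standalone sketch; the helper name and repo id are illustrative, not part of the commit:

from huggingface_hub import snapshot_download

def download_weights(repo_id: str, safetensors: bool = False) -> str:
    # Always skip Flax (*msgpack*) and TensorFlow (*h5*) checkpoints.
    ignore_patterns = ["*msgpack*", "*h5*"]
    if safetensors:
        # Keep only *.safetensors files.
        ignore_patterns.extend(["*.pt", "*.bin"])
    else:
        # Keep only PyTorch *.pt / *.bin files.
        ignore_patterns.append("*safetensors*")
    # Returns the local directory the snapshot was downloaded to.
    return snapshot_download(repo_id, ignore_patterns=ignore_patterns)

local_dir = download_weights("facebook/opt-125m")  # placeholder repo id

Filtering at download time avoids pulling gigabytes of weight files in formats the loader would never read.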
@@ -230,15 +238,33 @@ class BaseAWQForCausalLM:
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
             self._load_quantized_modules(self, model, w_bit, q_config)

         model.tie_weights()

-        # Load model weights
-        model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
+        try:
+            model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
+        except Exception as ex:
+            # Fallback to auto model if load_checkpoint_and_dispatch is not working
+            print(f'{ex} - falling back to AutoModelForCausalLM.from_pretrained')
+            device_map = infer_auto_device_map(model, no_split_module_classes=[self.layer_type], dtype=torch_dtype)
+            del model
+
+            # Load model weights
+            model = AutoModelForCausalLM.from_pretrained(model_filename, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch_dtype)

         model.eval()

         return self(model, model_type, is_quantized=is_quantized)
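The core of the robustness fix is the try/except around load_checkpoint_and_dispatch: if the fast dispatch path fails, the code plans a device map from the still-empty model, discards it, and reloads through AutoModelForCausalLM.from_pretrained with CPU/disk offload enabled. The same pattern as a self-contained sketch; the function name, layer_type value, and paths are placeholders, not part of the commit:

import torch
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch
from transformers import AutoModelForCausalLM

def dispatch_with_fallback(model, checkpoint_path, layer_type, torch_dtype=torch.float16):
    try:
        # Fast path: map the already-built module tree onto devices and
        # stream the checkpoint into it.
        return load_checkpoint_and_dispatch(
            model, checkpoint_path,
            device_map="balanced",
            no_split_module_classes=[layer_type],
        )
    except Exception as ex:
        print(f"{ex} - falling back to AutoModelForCausalLM.from_pretrained")
        # Plan placement from the empty model, then reload from scratch.
        device_map = infer_auto_device_map(
            model, no_split_module_classes=[layer_type], dtype=torch_dtype
        )
        del model
        return AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            device_map=device_map,
            offload_folder="offload",   # spill layers that don't fit to disk
            offload_state_dict=True,    # stage the state dict on disk while loading
            torch_dtype=torch_dtype,
        )

Per the commit's own TODO, model_filename is by this point a full path under the downloaded snapshot rather than a bare filename, hence the misleading name.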