OpenDAS / AutoAWQ / Commits

Commit 471f811b
Authored Aug 18, 2023 by Casper Hansen
Parent 916fdf97

Increase robustness of model loading

Showing 1 changed file with 33 additions and 7 deletions.

awq/models/base.py  +33  -7
@@ -2,18 +2,19 @@ import os
 import gc
 import torch
 import functools
+import accelerate
 import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
 from huggingface_hub import snapshot_download
 from awq.utils.calib_data import get_calib_dataset
-from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
 from awq.quantize.quantizer import pseudo_quantize_tensor
 from awq.quantize.qmodule import WQLinear, ScaledActivation
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
+from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
 
 
 class BaseAWQForCausalLM:
@@ -215,10 +216,17 @@ class BaseAWQForCausalLM:
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename, w_bit=4, q_config={},
-                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True, is_quantized=True):
+                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
+                       safetensors=False, is_quantized=True):
         # Download model if path is not a directory
         if not os.path.isdir(model_path):
-            model_path = snapshot_download(model_path)
+            ignore_patterns = ["*msgpack*", "*h5*"]
+            if safetensors:
+                ignore_patterns.extend(["*.pt", "*.bin"])
+            else:
+                ignore_patterns.append("*safetensors*")
+
+            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
 
         # TODO: Better naming, model_filename becomes a directory
         model_filename = model_path + f'/{model_filename}'
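The download-filtering change above can be exercised on its own. Below is a minimal sketch of the same logic against huggingface_hub's snapshot_download, whose ignore_patterns parameter does the actual file selection; the function name and repo id are illustrative placeholders, not part of this commit:

from huggingface_hub import snapshot_download

def download_weights(repo_id: str, safetensors: bool = False) -> str:
    # Always skip Flax (*.msgpack) and TensorFlow (*.h5) weight files.
    ignore_patterns = ["*msgpack*", "*h5*"]
    if safetensors:
        # Keep only *.safetensors shards by excluding the PyTorch formats.
        ignore_patterns.extend(["*.pt", "*.bin"])
    else:
        # Keep PyTorch *.pt/*.bin shards by excluding safetensors.
        ignore_patterns.append("*safetensors*")

    # Returns the local directory the snapshot was downloaded into.
    return snapshot_download(repo_id, ignore_patterns=ignore_patterns)

Either way, only one serialization format is fetched, which avoids downloading duplicate weights from repos that ship both.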
@@ -230,15 +238,33 @@ class BaseAWQForCausalLM:
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
 
         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
             self._load_quantized_modules(self, model, w_bit, q_config)
 
         model.tie_weights()
 
         # Load model weights
-        model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
+        try:
+            model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
+        except Exception as ex:
+            # Fallback to auto model if load_checkpoint_and_dispatch is not working
+            print(f'{ex} - falling back to AutoModelForCausalLM.from_pretrained')
+            device_map = infer_auto_device_map(
+                model,
+                no_split_module_classes=[self.layer_type],
+                dtype=torch_dtype
+            )
+            del model
+
+            # Load model weights
+            model = AutoModelForCausalLM.from_pretrained(
+                model_filename,
+                device_map=device_map,
+                offload_folder="offload",
+                offload_state_dict=True,
+                torch_dtype=torch_dtype
+            )
+
+        model.eval()
 
         return self(model, model_type, is_quantized=is_quantized)
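For readers who want the fallback pattern outside of BaseAWQForCausalLM, here is a self-contained sketch using the same accelerate/transformers calls the diff relies on; load_with_fallback, checkpoint_dir, and no_split are illustrative names and are not taken from the commit:

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
from transformers import AutoConfig, AutoModelForCausalLM

def load_with_fallback(checkpoint_dir, no_split, dtype=torch.float16):
    # Build a weightless skeleton so accelerate can plan device placement
    # without allocating the full model in RAM.
    config = AutoConfig.from_pretrained(checkpoint_dir)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
    model.tie_weights()

    try:
        # Fast path: stream the checkpoint shards directly onto the
        # devices accelerate assigns.
        model = load_checkpoint_and_dispatch(
            model, checkpoint_dir, device_map="balanced",
            no_split_module_classes=no_split,
        )
    except Exception as ex:
        # Slow path: derive a device map from the skeleton, discard it, and
        # let transformers reload the checkpoint with CPU/disk offload.
        print(f"{ex} - falling back to AutoModelForCausalLM.from_pretrained")
        device_map = infer_auto_device_map(
            model, no_split_module_classes=no_split, dtype=dtype
        )
        del model
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint_dir,
            device_map=device_map,
            offload_folder="offload",
            offload_state_dict=True,
            torch_dtype=dtype,
        )
    return model.eval()

The design point mirrored from the commit: the empty-weight skeleton built for load_checkpoint_and_dispatch is reused to infer a device map before being deleted, so the fallback path never holds two full copies of the weights at once.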
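Taken together, a call site for the updated from_quantized might look like the following sketch; the repo id, filename, and q_config values are placeholders, and LlamaAWQForCausalLM is one of the concrete subclasses in this repository:

from awq.models.llama import LlamaAWQForCausalLM

# Hypothetical invocation; only the safetensors flag is new in this commit.
model = LlamaAWQForCausalLM.from_quantized(
    "user/llama-7b-awq",                    # resolved via snapshot_download when not a local dir
    model_type="llama",
    model_filename="awq_model_w4_g128.pt",  # joined as f'{model_path}/{model_filename}'
    w_bit=4,
    q_config={"zero_point": True, "q_group_size": 128},
    safetensors=False,                      # new flag: download *.pt/*.bin, skip *.safetensors
)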