Unverified commit c5be7cae, authored Sep 02, 2022 by Stas Bekman, committed by GitHub on Sep 02, 2022

postpone bnb load until it's needed (#18859)
Parent: 9e346f74

Showing 1 changed file with 8 additions and 2 deletions:

src/transformers/modeling_utils.py (+8, -2)
@@ -84,8 +84,6 @@ if is_accelerate_available():
     else:
         get_balanced_memory = None
 
-if is_bitsandbytes_available():
-    from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device
 
 
 logger = logging.get_logger(__name__)
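The two removed lines were the eager path: whenever bitsandbytes was installed, its helpers were imported at module import time, so every `import transformers` paid the cost even when 8-bit loading was never used. A minimal sketch of the deferred pattern the commit switches to, using a hypothetical package `heavy_dep` and helper `convert_model` as stand-ins (not the transformers API):

```python
# Sketch of the deferred-import pattern; `heavy_dep` / `convert_model`
# are hypothetical stand-ins for bitsandbytes and its helpers.
import importlib.util


def is_heavy_dep_available() -> bool:
    # Availability check that does not import (and therefore does not
    # initialize) the package itself.
    return importlib.util.find_spec("heavy_dep") is not None


def load_model(load_in_8bit: bool = False):
    if load_in_8bit:
        # Paid only on the 8-bit path; plain loading never imports it.
        # Raises ImportError here (not at startup) if heavy_dep is absent.
        from heavy_dep import convert_model
        return convert_model()
    return None
```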
@@ -527,6 +525,9 @@ def _load_state_dict_into_meta_model(
     # - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
     #   they won't get loaded.
 
+    if load_in_8bit:
+        from .utils.bitsandbytes import set_module_8bit_tensor_to_device
+
     error_msgs = []
 
     old_keys = []
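Placing the guarded import inside `_load_state_dict_into_meta_model` means it can execute on every call, which is safe: after the first import Python caches the module in `sys.modules`, so subsequent imports are dictionary lookups. A small stdlib-only demonstration (`json` stands in for bitsandbytes):

```python
# Repeated function-level imports are cheap after the first call:
# Python caches modules in sys.modules, so re-imports are dict lookups.
import sys
import time


def use_json():
    from json import dumps  # imports once, then hits the module cache
    return dumps({"ok": True})


use_json()
assert "json" in sys.modules  # cached after the first call

start = time.perf_counter()
use_json()
print(f"second call (cached import): {time.perf_counter() - start:.6f}s")
```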
@@ -2142,6 +2143,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             model = cls(config, *model_args, **model_kwargs)
 
         if load_in_8bit:
+            from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear
+
             logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
 
             # We never convert lm_head or any last modules for numerical stability reasons
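One trade-off of deferring: an eager module-level import fails at startup, while a deferred one fails only once the 8-bit branch actually runs. If an early, explicit failure is preferred, a guard like the following could sit at the top of the branch (a sketch only; the helper name and error text are illustrative, not transformers code):

```python
# Fail fast with a clear message when the optional dependency is missing.
# The helper name and error text are illustrative only.
import importlib.util


def require_bitsandbytes() -> None:
    if importlib.util.find_spec("bitsandbytes") is None:
        raise ImportError(
            "load_in_8bit=True requires bitsandbytes; "
            "install it with `pip install bitsandbytes`."
        )
```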
@@ -2279,6 +2282,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         dtype=None,
         load_in_8bit=False,
     ):
+        if load_in_8bit:
+            from .utils.bitsandbytes import set_module_8bit_tensor_to_device
+
         if device_map is not None and "disk" in device_map.values():
             if offload_folder is None:
                 raise ValueError(
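For reference, all four touched code paths are reached through 8-bit loading. A usage sketch, assuming bitsandbytes is installed and a CUDA GPU is available; the checkpoint name is only an example, and the exact kwargs should be checked against your transformers version:

```python
# Usage sketch: triggers the deferred bitsandbytes imports above.
# Assumes bitsandbytes is installed and a CUDA device is available;
# "bigscience/bloom-560m" is only an example checkpoint.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",
    device_map="auto",   # accelerate computes the device placement
    load_in_8bit=True,   # takes the 8-bit code paths patched here
)
```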