Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a638de19
Unverified
Commit
a638de19
authored
Jan 26, 2024
by
Yih-Dar
Committed by
GitHub
Jan 26, 2024
Browse files
Fix `weights_only` (#28725)
fix Co-authored-by:
ydshieh
<
ydshieh@users.noreply.github.com
>
parent
d6ac8f4a
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
20 additions
and
15 deletions
+20
-15
src/transformers/convert_pytorch_checkpoint_to_tf2.py
src/transformers/convert_pytorch_checkpoint_to_tf2.py
+2
-1
src/transformers/modeling_flax_pytorch_utils.py
src/transformers/modeling_flax_pytorch_utils.py
+4
-2
src/transformers/modeling_tf_pytorch_utils.py
src/transformers/modeling_tf_pytorch_utils.py
+2
-1
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+4
-6
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
+2
-1
src/transformers/trainer.py
src/transformers/trainer.py
+6
-4
No files found.
src/transformers/convert_pytorch_checkpoint_to_tf2.py
View file @
a638de19
...
@@ -330,10 +330,11 @@ def convert_pt_checkpoint_to_tf(
...
@@ -330,10 +330,11 @@ def convert_pt_checkpoint_to_tf(
if
compare_with_pt_model
:
if
compare_with_pt_model
:
tfo
=
tf_model
(
tf_model
.
dummy_inputs
,
training
=
False
)
# build the network
tfo
=
tf_model
(
tf_model
.
dummy_inputs
,
training
=
False
)
# build the network
weights_only_kwarg
=
{
"weights_only"
:
True
}
if
is_torch_greater_or_equal_than_1_13
else
{}
state_dict
=
torch
.
load
(
state_dict
=
torch
.
load
(
pytorch_checkpoint_path
,
pytorch_checkpoint_path
,
map_location
=
"cpu"
,
map_location
=
"cpu"
,
weights_only
=
is_torch_greater_or_equal_than_1_13
,
**
weights_only
_kwarg
,
)
)
pt_model
=
pt_model_class
.
from_pretrained
(
pt_model
=
pt_model_class
.
from_pretrained
(
pretrained_model_name_or_path
=
None
,
config
=
config
,
state_dict
=
state_dict
pretrained_model_name_or_path
=
None
,
config
=
config
,
state_dict
=
state_dict
...
...
src/transformers/modeling_flax_pytorch_utils.py
View file @
a638de19
...
@@ -74,7 +74,8 @@ def load_pytorch_checkpoint_in_flax_state_dict(
...
@@ -74,7 +74,8 @@ def load_pytorch_checkpoint_in_flax_state_dict(
)
)
raise
raise
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
"cpu"
,
weights_only
=
is_torch_greater_or_equal_than_1_13
)
weights_only_kwarg
=
{
"weights_only"
:
True
}
if
is_torch_greater_or_equal_than_1_13
else
{}
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
"cpu"
,
**
weights_only_kwarg
)
logger
.
info
(
f
"PyTorch checkpoint contains
{
sum
(
t
.
numel
()
for
t
in
pt_state_dict
.
values
()):,
}
parameters."
)
logger
.
info
(
f
"PyTorch checkpoint contains
{
sum
(
t
.
numel
()
for
t
in
pt_state_dict
.
values
()):,
}
parameters."
)
flax_state_dict
=
convert_pytorch_state_dict_to_flax
(
pt_state_dict
,
flax_model
)
flax_state_dict
=
convert_pytorch_state_dict_to_flax
(
pt_state_dict
,
flax_model
)
...
@@ -252,7 +253,8 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
...
@@ -252,7 +253,8 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
flax_state_dict
=
{}
flax_state_dict
=
{}
for
shard_file
in
shard_filenames
:
for
shard_file
in
shard_filenames
:
# load using msgpack utils
# load using msgpack utils
pt_state_dict
=
torch
.
load
(
shard_file
,
weights_only
=
is_torch_greater_or_equal_than_1_13
)
weights_only_kwarg
=
{
"weights_only"
:
True
}
if
is_torch_greater_or_equal_than_1_13
else
{}
pt_state_dict
=
torch
.
load
(
shard_file
,
**
weights_only_kwarg
)
pt_state_dict
=
{
k
:
v
.
numpy
()
for
k
,
v
in
pt_state_dict
.
items
()}
pt_state_dict
=
{
k
:
v
.
numpy
()
for
k
,
v
in
pt_state_dict
.
items
()}
model_prefix
=
flax_model
.
base_model_prefix
model_prefix
=
flax_model
.
base_model_prefix
...
...
src/transformers/modeling_tf_pytorch_utils.py
View file @
a638de19
...
@@ -188,7 +188,8 @@ def load_pytorch_checkpoint_in_tf2_model(
...
@@ -188,7 +188,8 @@ def load_pytorch_checkpoint_in_tf2_model(
if
pt_path
.
endswith
(
".safetensors"
):
if
pt_path
.
endswith
(
".safetensors"
):
state_dict
=
safe_load_file
(
pt_path
)
state_dict
=
safe_load_file
(
pt_path
)
else
:
else
:
state_dict
=
torch
.
load
(
pt_path
,
map_location
=
"cpu"
,
weights_only
=
is_torch_greater_or_equal_than_1_13
)
weights_only_kwarg
=
{
"weights_only"
:
True
}
if
is_torch_greater_or_equal_than_1_13
else
{}
state_dict
=
torch
.
load
(
pt_path
,
map_location
=
"cpu"
,
**
weights_only_kwarg
)
pt_state_dict
.
update
(
state_dict
)
pt_state_dict
.
update
(
state_dict
)
...
...
src/transformers/modeling_utils.py
View file @
a638de19
...
@@ -482,11 +482,8 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
...
@@ -482,11 +482,8 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
error_message
+=
f
"
\n
Missing key(s):
{
str_unexpected_keys
}
."
error_message
+=
f
"
\n
Missing key(s):
{
str_unexpected_keys
}
."
raise
RuntimeError
(
error_message
)
raise
RuntimeError
(
error_message
)
loader
=
(
weights_only_kwarg
=
{
"weights_only"
:
True
}
if
is_torch_greater_or_equal_than_1_13
else
{}
safe_load_file
loader
=
safe_load_file
if
load_safe
else
partial
(
torch
.
load
,
map_location
=
"cpu"
,
**
weights_only_kwarg
)
if
load_safe
else
partial
(
torch
.
load
,
map_location
=
"cpu"
,
weights_only
=
is_torch_greater_or_equal_than_1_13
)
)
for
shard_file
in
shard_files
:
for
shard_file
in
shard_files
:
state_dict
=
loader
(
os
.
path
.
join
(
folder
,
shard_file
))
state_dict
=
loader
(
os
.
path
.
join
(
folder
,
shard_file
))
...
@@ -530,10 +527,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
...
@@ -530,10 +527,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
and
is_zipfile
(
checkpoint_file
)
and
is_zipfile
(
checkpoint_file
)
):
):
extra_args
=
{
"mmap"
:
True
}
extra_args
=
{
"mmap"
:
True
}
weights_only_kwarg
=
{
"weights_only"
:
True
}
if
is_torch_greater_or_equal_than_1_13
else
{}
return
torch
.
load
(
return
torch
.
load
(
checkpoint_file
,
checkpoint_file
,
map_location
=
map_location
,
map_location
=
map_location
,
weights_only
=
is_torch_greater_or_equal_than_1_13
,
**
weights_only
_kwarg
,
**
extra_args
,
**
extra_args
,
)
)
except
Exception
as
e
:
except
Exception
as
e
:
...
...
src/transformers/models/wav2vec2/modeling_wav2vec2.py
View file @
a638de19
...
@@ -1334,10 +1334,11 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
...
@@ -1334,10 +1334,11 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
)
)
weights_only_kwarg
=
{
"weights_only"
:
True
}
if
is_torch_greater_or_equal_than_1_13
else
{}
state_dict
=
torch
.
load
(
state_dict
=
torch
.
load
(
weight_path
,
weight_path
,
map_location
=
"cpu"
,
map_location
=
"cpu"
,
weights_only
=
is_torch_greater_or_equal_than_1_13
,
**
weights_only
_kwarg
,
)
)
except
EnvironmentError
:
except
EnvironmentError
:
...
...
src/transformers/trainer.py
View file @
a638de19
...
@@ -2102,6 +2102,7 @@ class Trainer:
...
@@ -2102,6 +2102,7 @@ class Trainer:
)
)
if
os
.
path
.
isfile
(
weights_file
)
or
os
.
path
.
isfile
(
safe_weights_file
)
or
is_fsdp_ckpt
:
if
os
.
path
.
isfile
(
weights_file
)
or
os
.
path
.
isfile
(
safe_weights_file
)
or
is_fsdp_ckpt
:
weights_only_kwarg
=
{
"weights_only"
:
True
}
if
is_torch_greater_or_equal_than_1_13
else
{}
# If the model is on the GPU, it still works!
# If the model is on the GPU, it still works!
if
is_sagemaker_mp_enabled
():
if
is_sagemaker_mp_enabled
():
if
os
.
path
.
isfile
(
os
.
path
.
join
(
resume_from_checkpoint
,
"user_content.pt"
)):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
resume_from_checkpoint
,
"user_content.pt"
)):
...
@@ -2120,7 +2121,7 @@ class Trainer:
...
@@ -2120,7 +2121,7 @@ class Trainer:
state_dict
=
torch
.
load
(
state_dict
=
torch
.
load
(
weights_file
,
weights_file
,
map_location
=
"cpu"
,
map_location
=
"cpu"
,
weights_only
=
is_torch_greater_or_equal_than_1_13
,
**
weights_only
_kwarg
,
)
)
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
state_dict
[
"_smp_is_partial"
]
=
False
state_dict
[
"_smp_is_partial"
]
=
False
...
@@ -2137,7 +2138,7 @@ class Trainer:
...
@@ -2137,7 +2138,7 @@ class Trainer:
state_dict
=
torch
.
load
(
state_dict
=
torch
.
load
(
weights_file
,
weights_file
,
map_location
=
"cpu"
,
map_location
=
"cpu"
,
weights_only
=
is_torch_greater_or_equal_than_1_13
,
**
weights_only
_kwarg
,
)
)
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
...
@@ -2190,6 +2191,7 @@ class Trainer:
...
@@ -2190,6 +2191,7 @@ class Trainer:
or
os
.
path
.
exists
(
best_safe_adapter_model_path
)
or
os
.
path
.
exists
(
best_safe_adapter_model_path
)
):
):
has_been_loaded
=
True
has_been_loaded
=
True
weights_only_kwarg
=
{
"weights_only"
:
True
}
if
is_torch_greater_or_equal_than_1_13
else
{}
if
is_sagemaker_mp_enabled
():
if
is_sagemaker_mp_enabled
():
if
os
.
path
.
isfile
(
os
.
path
.
join
(
self
.
state
.
best_model_checkpoint
,
"user_content.pt"
)):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
self
.
state
.
best_model_checkpoint
,
"user_content.pt"
)):
# If the 'user_content.pt' file exists, load with the new smp api.
# If the 'user_content.pt' file exists, load with the new smp api.
...
@@ -2209,7 +2211,7 @@ class Trainer:
...
@@ -2209,7 +2211,7 @@ class Trainer:
state_dict
=
torch
.
load
(
state_dict
=
torch
.
load
(
best_model_path
,
best_model_path
,
map_location
=
"cpu"
,
map_location
=
"cpu"
,
weights_only
=
is_torch_greater_or_equal_than_1_13
,
**
weights_only
_kwarg
,
)
)
state_dict
[
"_smp_is_partial"
]
=
False
state_dict
[
"_smp_is_partial"
]
=
False
...
@@ -2242,7 +2244,7 @@ class Trainer:
...
@@ -2242,7 +2244,7 @@ class Trainer:
state_dict
=
torch
.
load
(
state_dict
=
torch
.
load
(
best_model_path
,
best_model_path
,
map_location
=
"cpu"
,
map_location
=
"cpu"
,
weights_only
=
is_torch_greater_or_equal_than_1_13
,
**
weights_only
_kwarg
,
)
)
# If the model is on the GPU, it still works!
# If the model is on the GPU, it still works!
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment