Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
68ae3be7
Unverified
Commit
68ae3be7
authored
Nov 13, 2023
by
Lysandre Debut
Committed by
GitHub
Nov 13, 2023
Browse files
Fix `from_pt` flag when loading with `safetensors` (#27394)
* Fix * Tests * Fix
parent
9dc8fe1b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
67 additions
and
1 deletion
+67
-1
src/transformers/modeling_tf_pytorch_utils.py
src/transformers/modeling_tf_pytorch_utils.py
+7
-1
tests/models/mpnet/test_modeling_mpnet.py
tests/models/mpnet/test_modeling_mpnet.py
+4
-0
tests/models/wav2vec2/test_modeling_wav2vec2.py
tests/models/wav2vec2/test_modeling_wav2vec2.py
+6
-0
tests/test_modeling_common.py
tests/test_modeling_common.py
+50
-0
No files found.
src/transformers/modeling_tf_pytorch_utils.py
View file @
68ae3be7
...
...
@@ -166,6 +166,7 @@ def load_pytorch_checkpoint_in_tf2_model(
try
:
import
tensorflow
as
tf
# noqa: F401
import
torch
# noqa: F401
from
safetensors.torch
import
load_file
as
safe_load_file
# noqa: F401
except
ImportError
:
logger
.
error
(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
...
...
@@ -182,7 +183,12 @@ def load_pytorch_checkpoint_in_tf2_model(
for
path
in
pytorch_checkpoint_path
:
pt_path
=
os
.
path
.
abspath
(
path
)
logger
.
info
(
f
"Loading PyTorch weights from
{
pt_path
}
"
)
pt_state_dict
.
update
(
torch
.
load
(
pt_path
,
map_location
=
"cpu"
))
if
pt_path
.
endswith
(
".safetensors"
):
state_dict
=
safe_load_file
(
pt_path
)
else
:
state_dict
=
torch
.
load
(
pt_path
,
map_location
=
"cpu"
)
pt_state_dict
.
update
(
state_dict
)
logger
.
info
(
f
"PyTorch checkpoint contains
{
sum
(
t
.
numel
()
for
t
in
pt_state_dict
.
values
()):,
}
parameters"
)
...
...
tests/models/mpnet/test_modeling_mpnet.py
View file @
68ae3be7
...
...
@@ -246,6 +246,10 @@ class MPNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_mpnet_for_question_answering
(
*
config_and_inputs
)
@
unittest
.
skip
(
"This isn't passing but should, seems like a misconfiguration of tied weights."
)
def
test_tf_from_pt_safetensors
(
self
):
return
@
require_torch
class
MPNetModelIntegrationTest
(
unittest
.
TestCase
):
...
...
tests/models/wav2vec2/test_modeling_wav2vec2.py
View file @
68ae3be7
...
...
@@ -824,6 +824,12 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
# (Even with this call, there are still memory leak by ~0.04MB)
self
.
clear_torch_jit_class_registry
()
@
unittest
.
skip
(
"Need to investigate why config.do_stable_layer_norm is set to False here when it doesn't seem to be supported"
)
def
test_flax_from_pt_safetensors
(
self
):
return
@
require_torch
class
Wav2Vec2RobustModelTest
(
ModelTesterMixin
,
unittest
.
TestCase
):
...
...
tests/test_modeling_common.py
View file @
68ae3be7
...
...
@@ -105,6 +105,7 @@ if is_tf_available():
if
is_flax_available
():
import
jax.numpy
as
jnp
from
tests.test_modeling_flax_utils
import
check_models_equal
from
transformers.modeling_flax_pytorch_utils
import
(
convert_pytorch_state_dict_to_flax
,
load_flax_weights_in_pytorch_model
,
...
...
@@ -3219,6 +3220,55 @@ class ModelTesterMixin:
# with attention mask
_
=
model
(
dummy_input
,
attention_mask
=
dummy_attention_mask
)
@
is_pt_tf_cross_test
def
test_tf_from_pt_safetensors
(
self
):
for
model_class
in
self
.
all_model_classes
:
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
tf_model_class_name
=
"TF"
+
model_class
.
__name__
# Add the "TF" at the beginning
if
not
hasattr
(
transformers
,
tf_model_class_name
):
# transformers does not have this model in TF version yet
return
tf_model_class
=
getattr
(
transformers
,
tf_model_class_name
)
pt_model
=
model_class
(
config
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
pt_model
.
save_pretrained
(
tmpdirname
,
safe_serialization
=
True
)
tf_model_1
=
tf_model_class
.
from_pretrained
(
tmpdirname
,
from_pt
=
True
)
pt_model
.
save_pretrained
(
tmpdirname
,
safe_serialization
=
False
)
tf_model_2
=
tf_model_class
.
from_pretrained
(
tmpdirname
,
from_pt
=
True
)
# Check models are equal
for
p1
,
p2
in
zip
(
tf_model_1
.
weights
,
tf_model_2
.
weights
):
self
.
assertTrue
(
np
.
allclose
(
p1
.
numpy
(),
p2
.
numpy
()))
@
is_pt_flax_cross_test
def
test_flax_from_pt_safetensors
(
self
):
for
model_class
in
self
.
all_model_classes
:
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
flax_model_class_name
=
"Flax"
+
model_class
.
__name__
# Add the "Flax at the beginning
if
not
hasattr
(
transformers
,
flax_model_class_name
):
# transformers does not have this model in Flax version yet
return
flax_model_class
=
getattr
(
transformers
,
flax_model_class_name
)
pt_model
=
model_class
(
config
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
pt_model
.
save_pretrained
(
tmpdirname
,
safe_serialization
=
True
)
flax_model_1
=
flax_model_class
.
from_pretrained
(
tmpdirname
,
from_pt
=
True
)
pt_model
.
save_pretrained
(
tmpdirname
,
safe_serialization
=
False
)
flax_model_2
=
flax_model_class
.
from_pretrained
(
tmpdirname
,
from_pt
=
True
)
# Check models are equal
self
.
assertTrue
(
check_models_equal
(
flax_model_1
,
flax_model_2
))
global_rng
=
random
.
Random
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment