Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
68ae3be7
Unverified
Commit
68ae3be7
authored
Nov 13, 2023
by
Lysandre Debut
Committed by
GitHub
Nov 13, 2023
Browse files
Fix `from_pt` flag when loading with `safetensors` (#27394)
* Fix * Tests * Fix
parent
9dc8fe1b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
67 additions
and
1 deletion
+67
-1
src/transformers/modeling_tf_pytorch_utils.py
src/transformers/modeling_tf_pytorch_utils.py
+7
-1
tests/models/mpnet/test_modeling_mpnet.py
tests/models/mpnet/test_modeling_mpnet.py
+4
-0
tests/models/wav2vec2/test_modeling_wav2vec2.py
tests/models/wav2vec2/test_modeling_wav2vec2.py
+6
-0
tests/test_modeling_common.py
tests/test_modeling_common.py
+50
-0
No files found.
src/transformers/modeling_tf_pytorch_utils.py
View file @
68ae3be7
...
@@ -166,6 +166,7 @@ def load_pytorch_checkpoint_in_tf2_model(
...
@@ -166,6 +166,7 @@ def load_pytorch_checkpoint_in_tf2_model(
try
:
try
:
import
tensorflow
as
tf
# noqa: F401
import
tensorflow
as
tf
# noqa: F401
import
torch
# noqa: F401
import
torch
# noqa: F401
from
safetensors.torch
import
load_file
as
safe_load_file
# noqa: F401
except
ImportError
:
except
ImportError
:
logger
.
error
(
logger
.
error
(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
...
@@ -182,7 +183,12 @@ def load_pytorch_checkpoint_in_tf2_model(
...
@@ -182,7 +183,12 @@ def load_pytorch_checkpoint_in_tf2_model(
for
path
in
pytorch_checkpoint_path
:
for
path
in
pytorch_checkpoint_path
:
pt_path
=
os
.
path
.
abspath
(
path
)
pt_path
=
os
.
path
.
abspath
(
path
)
logger
.
info
(
f
"Loading PyTorch weights from
{
pt_path
}
"
)
logger
.
info
(
f
"Loading PyTorch weights from
{
pt_path
}
"
)
pt_state_dict
.
update
(
torch
.
load
(
pt_path
,
map_location
=
"cpu"
))
if
pt_path
.
endswith
(
".safetensors"
):
state_dict
=
safe_load_file
(
pt_path
)
else
:
state_dict
=
torch
.
load
(
pt_path
,
map_location
=
"cpu"
)
pt_state_dict
.
update
(
state_dict
)
logger
.
info
(
f
"PyTorch checkpoint contains
{
sum
(
t
.
numel
()
for
t
in
pt_state_dict
.
values
()):,
}
parameters"
)
logger
.
info
(
f
"PyTorch checkpoint contains
{
sum
(
t
.
numel
()
for
t
in
pt_state_dict
.
values
()):,
}
parameters"
)
...
...
tests/models/mpnet/test_modeling_mpnet.py
View file @
68ae3be7
...
@@ -246,6 +246,10 @@ class MPNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
...
@@ -246,6 +246,10 @@ class MPNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_mpnet_for_question_answering
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_mpnet_for_question_answering
(
*
config_and_inputs
)
@unittest.skip("This isn't passing but should, seems like a misconfiguration of tied weights.")
def test_tf_from_pt_safetensors(self):
    # Overrides the common ModelTesterMixin test and disables it for MPNet:
    # the PT->TF conversion mismatches on this architecture, suspected to be
    # a tied-weights configuration issue rather than a safetensors bug.
    # NOTE(review): remove this skip once the tied-weights setup is confirmed fixed.
    return
@
require_torch
@
require_torch
class
MPNetModelIntegrationTest
(
unittest
.
TestCase
):
class
MPNetModelIntegrationTest
(
unittest
.
TestCase
):
...
...
tests/models/wav2vec2/test_modeling_wav2vec2.py
View file @
68ae3be7
...
@@ -824,6 +824,12 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
...
@@ -824,6 +824,12 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
# (Even with this call, there are still memory leak by ~0.04MB)
# (Even with this call, there are still memory leak by ~0.04MB)
self
.
clear_torch_jit_class_registry
()
self
.
clear_torch_jit_class_registry
()
@unittest.skip("Need to investigate why config.do_stable_layer_norm is set to False here when it doesn't seem to be supported")
def test_flax_from_pt_safetensors(self):
    # Overrides the common ModelTesterMixin test and disables it for Wav2Vec2:
    # the Flax counterpart appears not to support do_stable_layer_norm=False,
    # so the PT->Flax comparison cannot run as-is.
    # NOTE(review): re-enable once the do_stable_layer_norm discrepancy is resolved.
    return
@
require_torch
@
require_torch
class
Wav2Vec2RobustModelTest
(
ModelTesterMixin
,
unittest
.
TestCase
):
class
Wav2Vec2RobustModelTest
(
ModelTesterMixin
,
unittest
.
TestCase
):
...
...
tests/test_modeling_common.py
View file @
68ae3be7
...
@@ -105,6 +105,7 @@ if is_tf_available():
...
@@ -105,6 +105,7 @@ if is_tf_available():
if
is_flax_available
():
if
is_flax_available
():
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
tests.test_modeling_flax_utils
import
check_models_equal
from
transformers.modeling_flax_pytorch_utils
import
(
from
transformers.modeling_flax_pytorch_utils
import
(
convert_pytorch_state_dict_to_flax
,
convert_pytorch_state_dict_to_flax
,
load_flax_weights_in_pytorch_model
,
load_flax_weights_in_pytorch_model
,
...
@@ -3219,6 +3220,55 @@ class ModelTesterMixin:
...
@@ -3219,6 +3220,55 @@ class ModelTesterMixin:
# with attention mask
# with attention mask
_
=
model
(
dummy_input
,
attention_mask
=
dummy_attention_mask
)
_
=
model
(
dummy_input
,
attention_mask
=
dummy_attention_mask
)
@is_pt_tf_cross_test
def test_tf_from_pt_safetensors(self):
    """Check that a TF model loads identical weights from a PyTorch checkpoint
    whether it was serialized with safetensors or with torch.save.

    For every PT model class under test, saves the model twice (with
    ``safe_serialization=True`` and ``False``), reloads each checkpoint through
    the TF class with ``from_pt=True``, and asserts all weights match.
    """
    for model_class in self.all_model_classes:
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning
        if not hasattr(transformers, tf_model_class_name):
            # transformers does not have this model in TF version yet.
            # Fix: use `continue` (not `return`) so one missing TF counterpart
            # does not silently skip the check for all remaining model classes.
            continue

        tf_model_class = getattr(transformers, tf_model_class_name)

        pt_model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            # Round-trip 1: safetensors checkpoint loaded via from_pt.
            pt_model.save_pretrained(tmpdirname, safe_serialization=True)
            tf_model_1 = tf_model_class.from_pretrained(tmpdirname, from_pt=True)

            # Round-trip 2: classic torch.save checkpoint (same directory is
            # reused; from_pretrained picks up whichever format is present).
            pt_model.save_pretrained(tmpdirname, safe_serialization=False)
            tf_model_2 = tf_model_class.from_pretrained(tmpdirname, from_pt=True)

            # Check models are equal
            for p1, p2 in zip(tf_model_1.weights, tf_model_2.weights):
                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@is_pt_flax_cross_test
def test_flax_from_pt_safetensors(self):
    """Check that a Flax model loads identical weights from a PyTorch checkpoint
    whether it was serialized with safetensors or with torch.save.

    For every PT model class under test, saves the model twice (with
    ``safe_serialization=True`` and ``False``), reloads each checkpoint through
    the Flax class with ``from_pt=True``, and asserts the two models are equal.
    """
    for model_class in self.all_model_classes:
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        flax_model_class_name = "Flax" + model_class.__name__  # Add the "Flax" at the beginning
        if not hasattr(transformers, flax_model_class_name):
            # transformers does not have this model in Flax version yet.
            # Fix: use `continue` (not `return`) so one missing Flax counterpart
            # does not silently skip the check for all remaining model classes.
            continue

        flax_model_class = getattr(transformers, flax_model_class_name)

        pt_model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            # Round-trip 1: safetensors checkpoint loaded via from_pt.
            pt_model.save_pretrained(tmpdirname, safe_serialization=True)
            flax_model_1 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)

            # Round-trip 2: classic torch.save checkpoint in the same directory.
            pt_model.save_pretrained(tmpdirname, safe_serialization=False)
            flax_model_2 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)

            # Check models are equal
            self.assertTrue(check_models_equal(flax_model_1, flax_model_2))
global_rng
=
random
.
Random
()
global_rng
=
random
.
Random
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment