Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ca1f1c86
Unverified
Commit
ca1f1c86
authored
Jun 01, 2022
by
Joao Gante
Committed by
GitHub
Jun 01, 2022
Browse files
CLI: tool to convert PT into TF weights and open hub PR (#17497)
parent
3766df4f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
208 additions
and
3 deletions
+208
-3
.circleci/config.yml
.circleci/config.yml
+4
-2
src/transformers/commands/pt_to_tf.py
src/transformers/commands/pt_to_tf.py
+186
-0
src/transformers/commands/transformers_cli.py
src/transformers/commands/transformers_cli.py
+2
-0
tests/utils/test_cli.py
tests/utils/test_cli.py
+16
-1
No files found.
.circleci/config.yml
View file @
ca1f1c86
...
@@ -78,7 +78,8 @@ jobs:
...
@@ -78,7 +78,8 @@ jobs:
keys
:
keys
:
-
v0.4-torch_and_tf-{{ checksum "setup.py" }}
-
v0.4-torch_and_tf-{{ checksum "setup.py" }}
-
v0.4-{{ checksum "setup.py" }}
-
v0.4-{{ checksum "setup.py" }}
-
run
:
sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng
-
run
:
sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng git-lfs
-
run
:
git lfs install
-
run
:
pip install --upgrade pip
-
run
:
pip install --upgrade pip
-
run
:
pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,torch-speech,vision]
-
run
:
pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,torch-speech,vision]
-
run
:
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
-
run
:
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
...
@@ -117,7 +118,8 @@ jobs:
...
@@ -117,7 +118,8 @@ jobs:
keys
:
keys
:
-
v0.4-torch_and_tf-{{ checksum "setup.py" }}
-
v0.4-torch_and_tf-{{ checksum "setup.py" }}
-
v0.4-{{ checksum "setup.py" }}
-
v0.4-{{ checksum "setup.py" }}
-
run
:
sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng
-
run
:
sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng git-lfs
-
run
:
git lfs install
-
run
:
pip install --upgrade pip
-
run
:
pip install --upgrade pip
-
run
:
pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,torch-speech,vision]
-
run
:
pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,torch-speech,vision]
-
run
:
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
-
run
:
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
...
...
src/transformers/commands/pt_to_tf.py
0 → 100644
View file @
ca1f1c86
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
argparse
import
ArgumentParser
,
Namespace
import
numpy
as
np
from
datasets
import
load_dataset
from
huggingface_hub
import
Repository
,
upload_file
from
..
import
AutoFeatureExtractor
,
AutoModel
,
AutoTokenizer
,
TFAutoModel
,
is_tf_available
,
is_torch_available
from
..utils
import
logging
from
.
import
BaseTransformersCLICommand
# TensorFlow and PyTorch are both optional dependencies of transformers: guard the
# imports so this module can still be imported (e.g. for CLI registration) even when
# one of the frameworks is missing.
if is_tf_available():
    import tensorflow as tf

    # Disable TensorFloat-32 so TF matmuls run in full float32 precision; otherwise the
    # PT/TF hidden-state comparison below could fail for numeric reasons alone.
    tf.config.experimental.enable_tensor_float_32_execution(False)

if is_torch_available():
    import torch


# Maximum absolute difference tolerated between the PT and TF last hidden states.
MAX_ERROR = 5e-5  # larger error tolerance than in our internal tests, to avoid flaky user-facing errors
# Canonical file name for TensorFlow weights inside a hub model repository.
TF_WEIGHTS_NAME = "tf_model.h5"
def convert_command_factory(args: Namespace):
    """
    Factory function used to convert a model PyTorch checkpoint in a TensorFlow 2 checkpoint.

    Args:
        args: Parsed CLI arguments; must carry `model_name`, `local_dir` and `no_pr`.

    Returns: PTtoTFCommand
    """
    # NOTE: the docstring previously said "Returns: ServeCommand" — a copy-paste error
    # from the serving command; this factory builds a PTtoTFCommand.
    return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr)
class PTtoTFCommand(BaseTransformersCLICommand):
    """
    CLI command that clones a hub model repository, converts its PyTorch weights to
    TensorFlow, validates that both frameworks produce matching last hidden states,
    and (optionally) opens a hub PR with the converted weights.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        train_parser = parser.add_parser(
            "pt-to-tf",
            help=(
                "CLI tool to run convert a transformers model from a PyTorch checkpoint to a TensorFlow checkpoint."
                " Can also be used to validate existing weights without opening PRs, with --no-pr."
            ),
        )
        train_parser.add_argument(
            "--model-name",
            type=str,
            required=True,
            help="The model name, including owner/organization, as seen on the hub.",
        )
        train_parser.add_argument(
            "--local-dir",
            type=str,
            default="",
            help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
        )
        train_parser.add_argument(
            "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
        )
        train_parser.set_defaults(func=convert_command_factory)

    def __init__(self, model_name: str, local_dir: str, no_pr: bool, *args):
        # *args swallows any extra positional arguments the CLI framework passes along.
        self._logger = logging.get_logger("transformers-cli/pt_to_tf")
        self._model_name = model_name
        # Empty --local-dir falls back to a /tmp clone named after the model.
        self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
        self._no_pr = no_pr

    def get_text_inputs(self):
        """Build a small batched text input, returned as (PyTorch tensors, TF tensors)."""
        tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
        # Two rows of different lengths, so padding/attention-mask handling is exercised.
        sample_text = ["Hi there!", "I am a batch with more than one row and different input lengths."]
        if tokenizer.pad_token is None:
            # Decoder-only tokenizers (e.g. GPT-style) ship without a pad token.
            tokenizer.pad_token = tokenizer.eos_token
        pt_input = tokenizer(sample_text, return_tensors="pt", padding=True, truncation=True)
        tf_input = tokenizer(sample_text, return_tensors="tf", padding=True, truncation=True)
        return pt_input, tf_input

    def get_audio_inputs(self):
        """Build a two-sample audio input from a dummy dataset, as (PT tensors, TF tensors)."""
        processor = AutoFeatureExtractor.from_pretrained(self._local_dir)
        num_samples = 2
        # Small internal-testing dataset; downloaded on first use.
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
        raw_samples = [x["array"] for x in speech_samples]
        pt_input = processor(raw_samples, return_tensors="pt", padding=True)
        tf_input = processor(raw_samples, return_tensors="tf", padding=True)
        return pt_input, tf_input

    def get_image_inputs(self):
        """Build a two-image input from CIFAR-10, as (PT tensors, TF tensors)."""
        feature_extractor = AutoFeatureExtractor.from_pretrained(self._local_dir)
        num_samples = 2
        ds = load_dataset("cifar10", "plain_text", split="test")[:num_samples]["img"]
        pt_input = feature_extractor(images=ds, return_tensors="pt")
        tf_input = feature_extractor(images=ds, return_tensors="tf")
        return pt_input, tf_input

    def run(self):
        """
        Execute the conversion: clone/pull the repo, detect the input modality from the
        PT model, cross-load into TF, compare hidden states (twice: cross-loaded and
        reloaded-from-disk), and optionally open a hub PR with the TF weights.

        Raises:
            ValueError: if the modality cannot be detected, or if either comparison
                exceeds MAX_ERROR.
        """
        # Fetch remote data
        # TODO: implement a solution to pull a specific PR/commit, so we can use this CLI to validate pushes.
        repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
        repo.git_pull()  # in case the repo already exists locally, but with an older commit

        # Load models and acquire a basic input for its modality.
        pt_model = AutoModel.from_pretrained(self._local_dir)
        main_input_name = pt_model.main_input_name
        if main_input_name == "input_ids":
            pt_input, tf_input = self.get_text_inputs()
        elif main_input_name == "pixel_values":
            pt_input, tf_input = self.get_image_inputs()
        elif main_input_name == "input_features":
            pt_input, tf_input = self.get_audio_inputs()
        else:
            raise ValueError(f"Can't detect the model modality (`main_input_name` = {main_input_name})")
        # from_pt=True cross-loads the PyTorch checkpoint into the TF architecture.
        tf_from_pt_model = TFAutoModel.from_pretrained(self._local_dir, from_pt=True)

        # Extra input requirements, in addition to the input modality
        if hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"):
            # Encoder-decoder models need decoder_input_ids; seed both rows with the
            # configured decoder start token.
            decoder_input_ids = np.asarray([[1], [1]], dtype=int) * pt_model.config.decoder_start_token_id
            pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
            tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})

        # Confirms that cross loading PT weights into TF worked.
        pt_last_hidden_state = pt_model(**pt_input).last_hidden_state.detach().numpy()
        tf_from_pt_last_hidden_state = tf_from_pt_model(**tf_input).last_hidden_state.numpy()
        crossload_diff = np.max(np.abs(pt_last_hidden_state - tf_from_pt_last_hidden_state))
        if crossload_diff >= MAX_ERROR:
            raise ValueError(
                "The cross-loaded TF model has different last hidden states, something went wrong! (max difference ="
                f" {crossload_diff})"
            )

        # Save the weights in a TF format (if they don't exist) and confirms that the results are still good
        tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
        if not os.path.exists(tf_weights_path):
            tf_from_pt_model.save_weights(tf_weights_path)
        del tf_from_pt_model, pt_model  # will no longer be used, and may have a large memory footprint
        # Reload from the saved .h5 file to validate the on-disk weights too.
        tf_model = TFAutoModel.from_pretrained(self._local_dir)
        tf_last_hidden_state = tf_model(**tf_input).last_hidden_state.numpy()
        converted_diff = np.max(np.abs(pt_last_hidden_state - tf_last_hidden_state))
        if converted_diff >= MAX_ERROR:
            raise ValueError(
                "The converted TF model has different last hidden states, something went wrong! (max difference ="
                f" {converted_diff})"
            )

        if not self._no_pr:
            # TODO: remove try/except when the upload to PR feature is released
            # (https://github.com/huggingface/huggingface_hub/pull/884)
            try:
                # NOTE(review): `Logger.warn` is a deprecated alias of `Logger.warning`
                # — consider switching in a follow-up.
                self._logger.warn("Uploading the weights into a new PR...")
                hub_pr_url = upload_file(
                    path_or_fileobj=tf_weights_path,
                    path_in_repo=TF_WEIGHTS_NAME,
                    repo_id=self._model_name,
                    create_pr=True,
                    pr_commit_summary="Add TF weights",
                    pr_commit_description=(
                        f"Validated by the `pt_to_tf` CLI. Max crossload hidden state difference={crossload_diff:.3e};"
                        f" Max converted hidden state difference={converted_diff:.3e}."
                    ),
                )
                self._logger.warn(f"PR open in {hub_pr_url}")
            except TypeError:
                # Installed huggingface_hub predates the create_pr kwarg: fall back to
                # telling the user how to open the PR manually.
                self._logger.warn(
                    f"You can now open a PR in https://huggingface.co/{self._model_name}/discussions, manually"
                    f" uploading the file in {tf_weights_path}"
                )
src/transformers/commands/transformers_cli.py
View file @
ca1f1c86
...
@@ -21,6 +21,7 @@ from .convert import ConvertCommand
...
@@ -21,6 +21,7 @@ from .convert import ConvertCommand
from
.download
import
DownloadCommand
from
.download
import
DownloadCommand
from
.env
import
EnvironmentCommand
from
.env
import
EnvironmentCommand
from
.lfs
import
LfsCommands
from
.lfs
import
LfsCommands
from
.pt_to_tf
import
PTtoTFCommand
from
.run
import
RunCommand
from
.run
import
RunCommand
from
.serving
import
ServeCommand
from
.serving
import
ServeCommand
from
.user
import
UserCommands
from
.user
import
UserCommands
...
@@ -40,6 +41,7 @@ def main():
...
@@ -40,6 +41,7 @@ def main():
AddNewModelCommand
.
register_subcommand
(
commands_parser
)
AddNewModelCommand
.
register_subcommand
(
commands_parser
)
AddNewModelLikeCommand
.
register_subcommand
(
commands_parser
)
AddNewModelLikeCommand
.
register_subcommand
(
commands_parser
)
LfsCommands
.
register_subcommand
(
commands_parser
)
LfsCommands
.
register_subcommand
(
commands_parser
)
PTtoTFCommand
.
register_subcommand
(
commands_parser
)
# Let's go
# Let's go
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
tests/utils/test_cli.py
View file @
ca1f1c86
...
@@ -13,10 +13,12 @@
...
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
shutil
import
unittest
import
unittest
from
unittest.mock
import
patch
from
unittest.mock
import
patch
from
transformers.testing_utils
import
CaptureStd
from
transformers.testing_utils
import
CaptureStd
,
is_pt_tf_cross_test
class
CLITest
(
unittest
.
TestCase
):
class
CLITest
(
unittest
.
TestCase
):
...
@@ -30,3 +32,16 @@ class CLITest(unittest.TestCase):
...
@@ -30,3 +32,16 @@ class CLITest(unittest.TestCase):
self
.
assertIn
(
"Python version"
,
cs
.
out
)
self
.
assertIn
(
"Python version"
,
cs
.
out
)
self
.
assertIn
(
"Platform"
,
cs
.
out
)
self
.
assertIn
(
"Platform"
,
cs
.
out
)
self
.
assertIn
(
"Using distributed or parallel set-up in script?"
,
cs
.
out
)
self
.
assertIn
(
"Using distributed or parallel set-up in script?"
,
cs
.
out
)
@is_pt_tf_cross_test
@patch("sys.argv", ["fakeprogrampath", "pt-to-tf", "--model-name", "hf-internal-testing/tiny-random-gptj", "--no-pr"])
def test_cli_pt_to_tf(self):
    """
    End-to-end check of the `pt-to-tf` CLI subcommand: runs the real CLI entry point
    (with sys.argv patched, --no-pr so nothing is uploaded) against a tiny hub model
    and asserts that TF weights were produced on disk. Requires network access and
    both PT and TF installed (hence the cross-test marker).
    """
    import transformers.commands.transformers_cli

    shutil.rmtree("/tmp/hf-internal-testing/tiny-random-gptj", ignore_errors=True)  # cleans potential past runs
    transformers.commands.transformers_cli.main()

    # The original repo has no TF weights -- if they exist, they were created by the CLI
    self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5"))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment