Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c3c62b5d
Unverified
Commit
c3c62b5d
authored
Jun 15, 2022
by
Joao Gante
Committed by
GitHub
Jun 15, 2022
Browse files
CLI: Add flag to push TF weights directly into main (#17720)
* Add flag to push weights directly into main
parent
6ebeeeef
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
6 deletions
+17
-6
src/transformers/commands/pt_to_tf.py
src/transformers/commands/pt_to_tf.py
+17
-6
No files found.
src/transformers/commands/pt_to_tf.py
View file @
c3c62b5d
...
...
@@ -45,7 +45,7 @@ def convert_command_factory(args: Namespace):
Returns: ServeCommand
"""
return
PTtoTFCommand
(
args
.
model_name
,
args
.
local_dir
,
args
.
n
o_pr
,
args
.
new_weights
)
return
PTtoTFCommand
(
args
.
model_name
,
args
.
local_dir
,
args
.
n
ew_weights
,
args
.
no_pr
,
args
.
push
)
class
PTtoTFCommand
(
BaseTransformersCLICommand
):
...
...
@@ -76,13 +76,18 @@ class PTtoTFCommand(BaseTransformersCLICommand):
default
=
""
,
help
=
"Optional local directory of the model repository. Defaults to /tmp/{model_name}"
,
)
train_parser
.
add_argument
(
"--new-weights"
,
action
=
"store_true"
,
help
=
"Optional flag to create new TensorFlow weights, even if they already exist."
,
)
train_parser
.
add_argument
(
"--no-pr"
,
action
=
"store_true"
,
help
=
"Optional flag to NOT open a PR with converted weights."
)
train_parser
.
add_argument
(
"--
new-weights
"
,
"--
push
"
,
action
=
"store_true"
,
help
=
"Optional flag to
create new TensorFlow weights, even if they already exist.
"
,
help
=
"Optional flag to
push the weights directly to `main` (requires permissions)
"
,
)
train_parser
.
set_defaults
(
func
=
convert_command_factory
)
...
...
@@ -129,12 +134,13 @@ class PTtoTFCommand(BaseTransformersCLICommand):
return
_find_pt_tf_differences
(
pt_outputs
,
tf_outputs
,
{})
def
__init__
(
self
,
model_name
:
str
,
local_dir
:
str
,
no_pr
:
bool
,
new_weights
:
bool
,
*
args
):
def
__init__
(
self
,
model_name
:
str
,
local_dir
:
str
,
new_weights
:
bool
,
no_pr
:
bool
,
push
:
bool
,
*
args
):
self
.
_logger
=
logging
.
get_logger
(
"transformers-cli/pt_to_tf"
)
self
.
_model_name
=
model_name
self
.
_local_dir
=
local_dir
if
local_dir
else
os
.
path
.
join
(
"/tmp"
,
model_name
)
self
.
_no_pr
=
no_pr
self
.
_new_weights
=
new_weights
self
.
_no_pr
=
no_pr
self
.
_push
=
push
def
get_text_inputs
(
self
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
_local_dir
)
...
...
@@ -234,7 +240,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
)
)
if
not
self
.
_no_pr
:
if
self
.
_push
:
repo
.
git_add
(
auto_lfs_track
=
True
)
repo
.
git_commit
(
"Add TF weights"
)
repo
.
git_push
(
blocking
=
True
)
# this prints a progress bar with the upload
self
.
_logger
.
warn
(
f
"TF weights pushed into
{
self
.
_model_name
}
"
)
elif
not
self
.
_no_pr
:
# TODO: remove try/except when the upload to PR feature is released
# (https://github.com/huggingface/huggingface_hub/pull/884)
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment