Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c3c62b5d
Unverified
Commit
c3c62b5d
authored
Jun 15, 2022
by
Joao Gante
Committed by
GitHub
Jun 15, 2022
Browse files
CLI: Add flag to push TF weights directly into main (#17720)
* Add flag to push weights directly into main
parent
6ebeeeef
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
6 deletions
+17
-6
src/transformers/commands/pt_to_tf.py
src/transformers/commands/pt_to_tf.py
+17
-6
No files found.
src/transformers/commands/pt_to_tf.py
View file @
c3c62b5d
...
...
@@ -45,7 +45,7 @@ def convert_command_factory(args: Namespace):
Returns: ServeCommand
"""
return
PTtoTFCommand
(
args
.
model_name
,
args
.
local_dir
,
args
.
n
o_pr
,
args
.
new_weights
)
return
PTtoTFCommand
(
args
.
model_name
,
args
.
local_dir
,
args
.
n
ew_weights
,
args
.
no_pr
,
args
.
push
)
class
PTtoTFCommand
(
BaseTransformersCLICommand
):
...
...
@@ -76,13 +76,18 @@ class PTtoTFCommand(BaseTransformersCLICommand):
default
=
""
,
help
=
"Optional local directory of the model repository. Defaults to /tmp/{model_name}"
,
)
train_parser
.
add_argument
(
"--new-weights"
,
action
=
"store_true"
,
help
=
"Optional flag to create new TensorFlow weights, even if they already exist."
,
)
train_parser
.
add_argument
(
"--no-pr"
,
action
=
"store_true"
,
help
=
"Optional flag to NOT open a PR with converted weights."
)
train_parser
.
add_argument
(
"--
new-weights
"
,
"--
push
"
,
action
=
"store_true"
,
help
=
"Optional flag to
create new TensorFlow weights, even if they already exist.
"
,
help
=
"Optional flag to
push the weights directly to `main` (requires permissions)
"
,
)
train_parser
.
set_defaults
(
func
=
convert_command_factory
)
...
...
@@ -129,12 +134,13 @@ class PTtoTFCommand(BaseTransformersCLICommand):
return
_find_pt_tf_differences
(
pt_outputs
,
tf_outputs
,
{})
def
__init__
(
self
,
model_name
:
str
,
local_dir
:
str
,
no_pr
:
bool
,
new_weights
:
bool
,
*
args
):
def
__init__
(
self
,
model_name
:
str
,
local_dir
:
str
,
new_weights
:
bool
,
no_pr
:
bool
,
push
:
bool
,
*
args
):
self
.
_logger
=
logging
.
get_logger
(
"transformers-cli/pt_to_tf"
)
self
.
_model_name
=
model_name
self
.
_local_dir
=
local_dir
if
local_dir
else
os
.
path
.
join
(
"/tmp"
,
model_name
)
self
.
_no_pr
=
no_pr
self
.
_new_weights
=
new_weights
self
.
_no_pr
=
no_pr
self
.
_push
=
push
def
get_text_inputs
(
self
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
_local_dir
)
...
...
@@ -234,7 +240,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
)
)
if
not
self
.
_no_pr
:
if
self
.
_push
:
repo
.
git_add
(
auto_lfs_track
=
True
)
repo
.
git_commit
(
"Add TF weights"
)
repo
.
git_push
(
blocking
=
True
)
# this prints a progress bar with the upload
self
.
_logger
.
warn
(
f
"TF weights pushed into
{
self
.
_model_name
}
"
)
elif
not
self
.
_no_pr
:
# TODO: remove try/except when the upload to PR feature is released
# (https://github.com/huggingface/huggingface_hub/pull/884)
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment