Commit 693667b8 (Unverified)
Authored Feb 08, 2024 by Matt; committed by GitHub on Feb 08, 2024

Remove dead TF loading code (#28926)
Remove dead code
Parent: 115ac94d
Showing 1 changed file with 0 additions and 50 deletions.
src/transformers/modeling_tf_utils.py  +0  -50
...
@@ -32,7 +32,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 import h5py
 import numpy as np
 import tensorflow as tf
-from huggingface_hub import Repository, list_repo_files
 from packaging.version import parse

 from . import DataCollatorWithPadding, DefaultDataCollator
...
@@ -1356,55 +1355,6 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
         with open(extra_data_path, "wb") as f:
             pickle.dump(extra_data, f)

-    def load_repo_checkpoint(self, repo_path_or_name):
-        """
-        Loads a saved checkpoint (model weights and optimizer state) from a repo. Returns the current epoch count when
-        the checkpoint was made.
-
-        Args:
-            repo_path_or_name (`str`):
-                Can either be a repository name for your {object} in the Hub or a path to a local folder (in which case
-                the repository will have the name of that local folder).
-
-        Returns:
-            `dict`: A dictionary of extra metadata from the checkpoint, most commonly an "epoch" count.
-        """
-        if getattr(self, "optimizer", None) is None:
-            raise RuntimeError(
-                "Checkpoint loading failed as no optimizer is attached to the model. "
-                "This is most likely caused by the model not being compiled."
-            )
-        if os.path.isdir(repo_path_or_name):
-            local_dir = repo_path_or_name
-        else:
-            # If this isn't a local path, check that the remote repo exists and has a checkpoint in it
-            repo_files = list_repo_files(repo_path_or_name)
-            for file in ("checkpoint/weights.h5", "checkpoint/extra_data.pickle"):
-                if file not in repo_files:
-                    raise FileNotFoundError(f"Repo {repo_path_or_name} does not contain checkpoint file {file}!")
-            repo = Repository(repo_path_or_name.split("/")[-1], clone_from=repo_path_or_name)
-            local_dir = repo.local_dir
-
-        # Now make sure the repo actually has a checkpoint in it.
-        checkpoint_dir = os.path.join(local_dir, "checkpoint")
-        weights_file = os.path.join(checkpoint_dir, "weights.h5")
-        if not os.path.isfile(weights_file):
-            raise FileNotFoundError(f"Could not find checkpoint file weights.h5 in repo {repo_path_or_name}!")
-        extra_data_file = os.path.join(checkpoint_dir, "extra_data.pickle")
-        if not os.path.isfile(extra_data_file):
-            raise FileNotFoundError(f"Could not find checkpoint file extra_data.pickle in repo {repo_path_or_name}!")
-
-        # Assuming the repo is real and we got a checkpoint, load the weights and the optimizer state into the model.
-        # The optimizer state includes the iteration count, so learning rate schedules should resume as normal too.
-        self.load_weights(weights_file)
-        with open(extra_data_file, "rb") as f:
-            extra_data = pickle.load(f)
-        self.optimizer.set_weights(extra_data["optimizer_state"])
-
-        # Finally, return the epoch number from the checkpoint. This isn't a property of the model, so we can't
-        # set it directly, but the user can pass it to fit().
-        return {"epoch": extra_data["epoch"]}
-
     def prepare_tf_dataset(
         self,
         dataset: "datasets.Dataset",  # noqa:F821
...
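
For context, the deleted load_repo_checkpoint was the resume-from-checkpoint helper for Keras training loops. Below is a minimal usage sketch reconstructed from the deleted docstring and body, as it worked before this commit; the model class, checkpoint repo name, and train_dataset are illustrative placeholders, not part of the diff.

from transformers import TFAutoModelForSequenceClassification

# Hypothetical setup: any TF model works, but it must be compiled, since the
# method raised RuntimeError when no optimizer was attached.
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
model.compile(optimizer="adam")

# The repo (or local folder) had to contain checkpoint/weights.h5 and
# checkpoint/extra_data.pickle. Loading restored the weights and the optimizer
# state, including the iteration count, so learning rate schedules resumed.
extra_data = model.load_repo_checkpoint("my-user/my-model")  # hypothetical repo name

# The epoch count is not a model property, so it was returned for the caller
# to pass to fit() when resuming training.
train_dataset = ...  # placeholder: a tf.data.Dataset prepared elsewhere
model.fit(train_dataset, initial_epoch=extra_data["epoch"])

The deleted method was evidently the only consumer of Repository and list_repo_files, which is why the huggingface_hub import in the first hunk is removed alongside it.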