Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
693667b8
Unverified
Commit
693667b8
authored
Feb 08, 2024
by
Matt
Committed by
GitHub
Feb 08, 2024
Browse files
Remove dead TF loading code (#28926)
Remove dead code
parent
115ac94d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
50 deletions
+0
-50
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+0
-50
No files found.
src/transformers/modeling_tf_utils.py
View file @
693667b8
...
...
@@ -32,7 +32,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import
h5py
import
numpy
as
np
import
tensorflow
as
tf
from
huggingface_hub
import
Repository
,
list_repo_files
from
packaging.version
import
parse
from
.
import
DataCollatorWithPadding
,
DefaultDataCollator
...
...
@@ -1356,55 +1355,6 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
with
open
(
extra_data_path
,
"wb"
)
as
f
:
pickle
.
dump
(
extra_data
,
f
)
def
load_repo_checkpoint
(
self
,
repo_path_or_name
):
"""
Loads a saved checkpoint (model weights and optimizer state) from a repo. Returns the current epoch count when
the checkpoint was made.
Args:
repo_path_or_name (`str`):
Can either be a repository name for your {object} in the Hub or a path to a local folder (in which case
the repository will have the name of that local folder).
Returns:
`dict`: A dictionary of extra metadata from the checkpoint, most commonly an "epoch" count.
"""
if
getattr
(
self
,
"optimizer"
,
None
)
is
None
:
raise
RuntimeError
(
"Checkpoint loading failed as no optimizer is attached to the model. "
"This is most likely caused by the model not being compiled."
)
if
os
.
path
.
isdir
(
repo_path_or_name
):
local_dir
=
repo_path_or_name
else
:
# If this isn't a local path, check that the remote repo exists and has a checkpoint in it
repo_files
=
list_repo_files
(
repo_path_or_name
)
for
file
in
(
"checkpoint/weights.h5"
,
"checkpoint/extra_data.pickle"
):
if
file
not
in
repo_files
:
raise
FileNotFoundError
(
f
"Repo
{
repo_path_or_name
}
does not contain checkpoint file
{
file
}
!"
)
repo
=
Repository
(
repo_path_or_name
.
split
(
"/"
)[
-
1
],
clone_from
=
repo_path_or_name
)
local_dir
=
repo
.
local_dir
# Now make sure the repo actually has a checkpoint in it.
checkpoint_dir
=
os
.
path
.
join
(
local_dir
,
"checkpoint"
)
weights_file
=
os
.
path
.
join
(
checkpoint_dir
,
"weights.h5"
)
if
not
os
.
path
.
isfile
(
weights_file
):
raise
FileNotFoundError
(
f
"Could not find checkpoint file weights.h5 in repo
{
repo_path_or_name
}
!"
)
extra_data_file
=
os
.
path
.
join
(
checkpoint_dir
,
"extra_data.pickle"
)
if
not
os
.
path
.
isfile
(
extra_data_file
):
raise
FileNotFoundError
(
f
"Could not find checkpoint file extra_data.pickle in repo
{
repo_path_or_name
}
!"
)
# Assuming the repo is real and we got a checkpoint, load the weights and the optimizer state into the model.
# The optimizer state includes the iteration count, so learning rate schedules should resume as normal too.
self
.
load_weights
(
weights_file
)
with
open
(
extra_data_file
,
"rb"
)
as
f
:
extra_data
=
pickle
.
load
(
f
)
self
.
optimizer
.
set_weights
(
extra_data
[
"optimizer_state"
])
# Finally, return the epoch number from the checkpoint. This isn't a property of the model, so we can't
# set it directly, but the user can pass it to fit().
return
{
"epoch"
:
extra_data
[
"epoch"
]}
def
prepare_tf_dataset
(
self
,
dataset
:
"datasets.Dataset"
,
# noqa:F821
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment