Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
143fd0b6
Commit
143fd0b6
authored
Jan 18, 2022
by
Le Hou
Committed by
A. Unique TensorFlower
Jan 18, 2022
Browse files
Minor bug fixes
PiperOrigin-RevId: 422637653
parent
871c4e0a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
4 deletions
+14
-4
official/core/base_task.py
official/core/base_task.py
+3
-1
official/nlp/tasks/dual_encoder.py
official/nlp/tasks/dual_encoder.py
+5
-1
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+6
-2
No files found.
official/core/base_task.py
View file @
143fd0b6
...
...
@@ -101,9 +101,11 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
ckpt_dir_or_file
=
self
.
task_config
.
init_checkpoint
logging
.
info
(
"Trying to load pretrained checkpoint from %s"
,
ckpt_dir_or_file
)
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
if
ckpt_dir_or_file
and
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
not
ckpt_dir_or_file
:
logging
.
info
(
"No checkpoint file found from %s. Will not load."
,
ckpt_dir_or_file
)
return
if
hasattr
(
model
,
"checkpoint_items"
):
...
...
official/nlp/tasks/dual_encoder.py
View file @
143fd0b6
...
...
@@ -187,9 +187,13 @@ class DualEncoderTask(base_task.Task):
def
initialize
(
self
,
model
):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file
=
self
.
task_config
.
init_checkpoint
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
logging
.
info
(
'Trying to load pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
if
ckpt_dir_or_file
and
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
not
ckpt_dir_or_file
:
logging
.
info
(
'No checkpoint file found from %s. Will not load.'
,
ckpt_dir_or_file
)
return
pretrain2finetune_mapping
=
{
...
...
official/nlp/tasks/sentence_prediction.py
View file @
143fd0b6
...
...
@@ -223,10 +223,14 @@ class SentencePredictionTask(base_task.Task):
def
initialize
(
self
,
model
):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file
=
self
.
task_config
.
init_checkpoint
logging
.
info
(
'Trying to load pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
if
ckpt_dir_or_file
and
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
not
ckpt_dir_or_file
:
logging
.
info
(
'No checkpoint file found from %s. Will not load.'
,
ckpt_dir_or_file
)
return
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
pretrain2finetune_mapping
=
{
'encoder'
:
model
.
checkpoint_items
[
'encoder'
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment