Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
66cc634f
Commit
66cc634f
authored
Jun 30, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Jun 30, 2020
Browse files
Update docstring to explicitly callout the methods run on remote hosts/devices.
PiperOrigin-RevId: 319157806
parent
30bac445
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
1 deletion
+16
-1
official/core/base_task.py
official/core/base_task.py
+16
-1
No files found.
official/core/base_task.py
View file @
66cc634f
...
@@ -107,6 +107,7 @@ class Task(tf.Module):
...
@@ -107,6 +107,7 @@ class Task(tf.Module):
"""Returns a dataset or a nested structure of dataset functions.
"""Returns a dataset or a nested structure of dataset functions.
Dataset functions define per-host datasets with the per-replica batch size.
Dataset functions define per-host datasets with the per-replica batch size.
With distributed training, this method runs on remote hosts.
Args:
Args:
params: hyperparams to create input pipelines.
params: hyperparams to create input pipelines.
...
@@ -172,6 +173,8 @@ class Task(tf.Module):
...
@@ -172,6 +173,8 @@ class Task(tf.Module):
metrics
=
None
):
metrics
=
None
):
"""Does forward and backward.
"""Does forward and backward.
With distribution strategies, this method runs on devices.
Args:
Args:
inputs: a dictionary of input tensors.
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
model: the model, forward pass definition.
...
@@ -219,6 +222,8 @@ class Task(tf.Module):
...
@@ -219,6 +222,8 @@ class Task(tf.Module):
def
validation_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
metrics
=
None
):
def
validation_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
metrics
=
None
):
"""Validatation step.
"""Validatation step.
With distribution strategies, this method runs on devices.
Args:
Args:
inputs: a dictionary of input tensors.
inputs: a dictionary of input tensors.
model: the keras.Model.
model: the keras.Model.
...
@@ -244,7 +249,17 @@ class Task(tf.Module):
...
@@ -244,7 +249,17 @@ class Task(tf.Module):
return
logs
return
logs
def
inference_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
):
def
inference_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
):
"""Performs the forward step."""
"""Performs the forward step.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
Returns:
Model outputs.
"""
return
model
(
inputs
,
training
=
False
)
return
model
(
inputs
,
training
=
False
)
def
aggregate_logs
(
self
,
state
,
step_logs
):
def
aggregate_logs
(
self
,
state
,
step_logs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment