"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "0375c800c767db2ef070cee1529d8a50f42d1042"
Commit 66cc634f authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Update docstring to explicitly callout the methods run on remote hosts/devices.

PiperOrigin-RevId: 319157806
parent 30bac445
...@@ -107,6 +107,7 @@ class Task(tf.Module): ...@@ -107,6 +107,7 @@ class Task(tf.Module):
"""Returns a dataset or a nested structure of dataset functions. """Returns a dataset or a nested structure of dataset functions.
Dataset functions define per-host datasets with the per-replica batch size. Dataset functions define per-host datasets with the per-replica batch size.
With distributed training, this method runs on remote hosts.
Args: Args:
params: hyperparams to create input pipelines. params: hyperparams to create input pipelines.
...@@ -172,6 +173,8 @@ class Task(tf.Module): ...@@ -172,6 +173,8 @@ class Task(tf.Module):
metrics=None): metrics=None):
"""Does forward and backward. """Does forward and backward.
With distribution strategies, this method runs on devices.
Args: Args:
inputs: a dictionary of input tensors. inputs: a dictionary of input tensors.
model: the model, forward pass definition. model: the model, forward pass definition.
...@@ -219,6 +222,8 @@ class Task(tf.Module): ...@@ -219,6 +222,8 @@ class Task(tf.Module):
def validation_step(self, inputs, model: tf.keras.Model, metrics=None): def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step. """Validatation step.
With distribution strategies, this method runs on devices.
Args: Args:
inputs: a dictionary of input tensors. inputs: a dictionary of input tensors.
model: the keras.Model. model: the keras.Model.
...@@ -244,7 +249,17 @@ class Task(tf.Module): ...@@ -244,7 +249,17 @@ class Task(tf.Module):
return logs return logs
def inference_step(self, inputs, model: tf.keras.Model): def inference_step(self, inputs, model: tf.keras.Model):
"""Performs the forward step.""" """Performs the forward step.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
Returns:
Model outputs.
"""
return model(inputs, training=False) return model(inputs, training=False)
def aggregate_logs(self, state, step_logs): def aggregate_logs(self, state, step_logs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment