Commit 66cc634f authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Update docstring to explicitly callout the methods run on remote hosts/devices.

PiperOrigin-RevId: 319157806
parent 30bac445
...@@ -107,6 +107,7 @@ class Task(tf.Module): ...@@ -107,6 +107,7 @@ class Task(tf.Module):
"""Returns a dataset or a nested structure of dataset functions. """Returns a dataset or a nested structure of dataset functions.
Dataset functions define per-host datasets with the per-replica batch size. Dataset functions define per-host datasets with the per-replica batch size.
With distributed training, this method runs on remote hosts.
Args: Args:
params: hyperparams to create input pipelines. params: hyperparams to create input pipelines.
...@@ -172,6 +173,8 @@ class Task(tf.Module): ...@@ -172,6 +173,8 @@ class Task(tf.Module):
metrics=None): metrics=None):
"""Does forward and backward. """Does forward and backward.
With distribution strategies, this method runs on devices.
Args: Args:
inputs: a dictionary of input tensors. inputs: a dictionary of input tensors.
model: the model, forward pass definition. model: the model, forward pass definition.
...@@ -219,6 +222,8 @@ class Task(tf.Module): ...@@ -219,6 +222,8 @@ class Task(tf.Module):
def validation_step(self, inputs, model: tf.keras.Model, metrics=None): def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step. """Validatation step.
With distribution strategies, this method runs on devices.
Args: Args:
inputs: a dictionary of input tensors. inputs: a dictionary of input tensors.
model: the keras.Model. model: the keras.Model.
...@@ -244,7 +249,17 @@ class Task(tf.Module): ...@@ -244,7 +249,17 @@ class Task(tf.Module):
return logs return logs
def inference_step(self, inputs, model: tf.keras.Model): def inference_step(self, inputs, model: tf.keras.Model):
"""Performs the forward step.""" """Performs the forward step.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
Returns:
Model outputs.
"""
return model(inputs, training=False) return model(inputs, training=False)
def aggregate_logs(self, state, step_logs): def aggregate_logs(self, state, step_logs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment