"git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "b8a212ef6bd9e79517a36b0a3094806585da0f38"
Commit 59353708 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Fix typo and grammar errors.

PiperOrigin-RevId: 319276332
parent b43418ad
...@@ -59,7 +59,7 @@ class Task(tf.Module): ...@@ -59,7 +59,7 @@ class Task(tf.Module):
def initialize(self, model: tf.keras.Model): def initialize(self, model: tf.keras.Model):
"""A callback function used as CheckpointManager's init_fn. """A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint found for the model. This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. You can use this callback function to load a pretrained will not be called. You can use this callback function to load a pretrained
checkpoint, saved under a directory other than the model_dir. checkpoint, saved under a directory other than the model_dir.
...@@ -71,7 +71,7 @@ class Task(tf.Module): ...@@ -71,7 +71,7 @@ class Task(tf.Module):
@abc.abstractmethod @abc.abstractmethod
def build_model(self) -> tf.keras.Model: def build_model(self) -> tf.keras.Model:
"""Creates the model architecture. """Creates model architecture.
Returns: Returns:
A model instance. A model instance.
...@@ -135,7 +135,7 @@ class Task(tf.Module): ...@@ -135,7 +135,7 @@ class Task(tf.Module):
Args: Args:
labels: optional label tensors. labels: optional label tensors.
model_outputs: a nested structure of output tensors. model_outputs: a nested structure of output tensors.
aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model. aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns: Returns:
The total loss tensor. The total loss tensor.
...@@ -232,7 +232,7 @@ class Task(tf.Module): ...@@ -232,7 +232,7 @@ class Task(tf.Module):
return logs return logs
def validation_step(self, inputs, model: tf.keras.Model, metrics=None): def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step. """Validation step.
With distribution strategies, this method runs on devices. With distribution strategies, this method runs on devices.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment