"sgl-kernel/git@developer.sourcefind.cn:change/sglang.git" did not exist on "40e5cb7a9c0e3ff5ef711e1cee0bdbe9714bcf97"
Commit 71d2680d authored by Yeqing Li, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 309486898
parent 10b38209
...@@ -134,17 +134,6 @@ class SummaryWriter(object): ...@@ -134,17 +134,6 @@ class SummaryWriter(object):
class DistributedExecutor(object): class DistributedExecutor(object):
"""Interface to train and eval models with tf.distribute.Strategy. """Interface to train and eval models with tf.distribute.Strategy.
Arguments:
strategy: an instance of tf.distribute.Strategy.
params: Model configuration needed to run distribution strategy.
model_fn: Keras model function. Signature:
(params: ParamsDict) -> tf.keras.models.Model.
loss_fn: loss function. Signature:
(y_true: Tensor, y_pred: Tensor) -> Tensor
metric_fn: metric function. Signature: () -> tf.keras.metrics.Metric.
is_multi_host: Set to True when using multi hosts for training, like multi
worker GPU or TPU pod (slice). Otherwise, False.
""" """
def __init__(self, def __init__(self,
...@@ -153,6 +142,18 @@ class DistributedExecutor(object): ...@@ -153,6 +142,18 @@ class DistributedExecutor(object):
model_fn, model_fn,
loss_fn, loss_fn,
is_multi_host=False): is_multi_host=False):
"""Constructor.
Args:
strategy: an instance of tf.distribute.Strategy.
params: Model configuration needed to run distribution strategy.
model_fn: Keras model function. Signature:
(params: ParamsDict) -> tf.keras.models.Model.
loss_fn: loss function. Signature:
(y_true: Tensor, y_pred: Tensor) -> Tensor
is_multi_host: Set to True when using multi hosts for training, like multi
worker GPU or TPU pod (slice). Otherwise, False.
"""
self._params = params self._params = params
self._model_fn = model_fn self._model_fn = model_fn
...@@ -224,6 +225,18 @@ class DistributedExecutor(object): ...@@ -224,6 +225,18 @@ class DistributedExecutor(object):
loss_fn, loss_fn,
optimizer, optimizer,
metric=None): metric=None):
"""Creates a single training step.
Args:
strategy: an instance of tf.distribute.Strategy.
model: (Tensor, bool) -> Tensor. model function.
loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
optimizer: tf.keras.optimizers.Optimizer.
metric: tf.keras.metrics.Metric subclass.
Returns:
The training step callable.
"""
metrics = metrics_as_dict(metric) metrics = metrics_as_dict(metric)
def _replicated_step(inputs): def _replicated_step(inputs):
...@@ -252,19 +265,18 @@ class DistributedExecutor(object): ...@@ -252,19 +265,18 @@ class DistributedExecutor(object):
metric=None): metric=None):
"""Creates a distributed training step. """Creates a distributed training step.
Args: Args:
strategy: an instance of tf.distribute.Strategy. strategy: an instance of tf.distribute.Strategy.
model: (Tensor, bool) -> Tensor. model function. model: (Tensor, bool) -> Tensor. model function.
loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor. loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
optimizer: tf.keras.optimizers.Optimizer. optimizer: tf.keras.optimizers.Optimizer.
iterator: an iterator that yields input tensors. metric: tf.keras.metrics.Metric subclass.
metric: tf.keras.metrics.Metric subclass.
Returns: Returns:
The training step callable. The training step callable.
""" """
_replicated_step = self._create_replicated_step(strategy, model, loss_fn, replicated_step = self._create_replicated_step(strategy, model, loss_fn,
optimizer, metric) optimizer, metric)
@tf.function @tf.function
def train_step(iterator, num_steps): def train_step(iterator, num_steps):
...@@ -282,10 +294,10 @@ class DistributedExecutor(object): ...@@ -282,10 +294,10 @@ class DistributedExecutor(object):
'retracing.') 'retracing.')
per_replica_losses = strategy.run( per_replica_losses = strategy.run(
_replicated_step, args=(next(iterator),)) replicated_step, args=(next(iterator),))
for _ in tf.range(num_steps - 1): for _ in tf.range(num_steps - 1):
per_replica_losses = strategy.run( per_replica_losses = strategy.run(
_replicated_step, args=(next(iterator),)) replicated_step, args=(next(iterator),))
# For reporting, we return the mean of losses. # For reporting, we return the mean of losses.
losses = tf.nest.map_structure( losses = tf.nest.map_structure(
...@@ -318,7 +330,6 @@ class DistributedExecutor(object): ...@@ -318,7 +330,6 @@ class DistributedExecutor(object):
return test_step return test_step
def train(self, def train(self,
train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset], train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
eval_input_fn: Callable[[params_dict.ParamsDict], eval_input_fn: Callable[[params_dict.ParamsDict],
...@@ -404,6 +415,7 @@ class DistributedExecutor(object): ...@@ -404,6 +415,7 @@ class DistributedExecutor(object):
train_iterator = self._get_input_iterator(train_input_fn, strategy) train_iterator = self._get_input_iterator(train_input_fn, strategy)
train_loss = None train_loss = None
eval_metric_result = None eval_metric_result = None
tf.keras.backend.set_learning_phase(1)
with strategy.scope(): with strategy.scope():
# To correctly place the model weights on accelerators, # To correctly place the model weights on accelerators,
# model and optimizer should be created in scope. # model and optimizer should be created in scope.
...@@ -584,10 +596,10 @@ class DistributedExecutor(object): ...@@ -584,10 +596,10 @@ class DistributedExecutor(object):
"""Runs distributed evaluation on model folder. """Runs distributed evaluation on model folder.
Args: Args:
model_dir: the folder for storing model checkpoints.
eval_input_fn: (Optional) same type as train_input_fn. If not None, will eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluating metric on eval data. If None, will not run eval step. trigger evaluating metric on eval data. If None, will not run eval step.
eval_metric_fn: metric_fn for evaluation in test_step. eval_metric_fn: metric_fn for evaluation in test_step.
model_dir: the folder for storing model checkpoints.
total_steps: total training steps. If the current step reaches the total_steps: total training steps. If the current step reaches the
total_steps, the evaluation loop will stop. total_steps, the evaluation loop will stop.
eval_timeout: The maximum number of seconds to wait between checkpoints. eval_timeout: The maximum number of seconds to wait between checkpoints.
...@@ -638,11 +650,11 @@ class DistributedExecutor(object): ...@@ -638,11 +650,11 @@ class DistributedExecutor(object):
"""Runs distributed evaluation on the one checkpoint. """Runs distributed evaluation on the one checkpoint.
Args: Args:
checkpoint_path: the checkpoint to evaluate.
eval_input_fn: (Optional) same type as train_input_fn. If not None, will eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluating metric on eval data. If None, will not run eval step. trigger evaluating metric on eval data. If None, will not run eval step.
eval_metric_fn: metric_fn for evaluation in test_step. eval_metric_fn: metric_fn for evaluation in test_step.
checkpoint_path: the checkpoint to evaluate. summary_writer: function to create summary writer.
summary_writer_fn: function to create summary writer.
Returns: Returns:
Eval metrics dictionary of the last checkpoint. Eval metrics dictionary of the last checkpoint.
...@@ -651,6 +663,8 @@ class DistributedExecutor(object): ...@@ -651,6 +663,8 @@ class DistributedExecutor(object):
raise ValueError('if `eval_metric_fn` is specified, ' raise ValueError('if `eval_metric_fn` is specified, '
'eval_metric_fn must be a callable.') 'eval_metric_fn must be a callable.')
old_phrase = tf.keras.backend.learning_phase()
tf.keras.backend.set_learning_phase(0)
params = self._params params = self._params
strategy = self._strategy strategy = self._strategy
# To reduce unnecessary send/receive input pipeline operation, we place # To reduce unnecessary send/receive input pipeline operation, we place
...@@ -686,6 +700,7 @@ class DistributedExecutor(object): ...@@ -686,6 +700,7 @@ class DistributedExecutor(object):
summary_writer(metrics=eval_metric_result, step=current_step) summary_writer(metrics=eval_metric_result, step=current_step)
reset_states(eval_metric) reset_states(eval_metric)
tf.keras.backend.set_learning_phase(old_phrase)
return eval_metric_result, current_step return eval_metric_result, current_step
def predict(self): def predict(self):
...@@ -726,18 +741,20 @@ class ExecutorBuilder(object): ...@@ -726,18 +741,20 @@ class ExecutorBuilder(object):
model_fn=my_model_fn, model_fn=my_model_fn,
loss_fn=my_loss_fn, loss_fn=my_loss_fn,
metric_fn=my_metric_fn) metric_fn=my_metric_fn)
Args:
strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'. If
None, the user is responsible for setting the strategy before calling
build_executor(...).
strategy_config: necessary config for constructing the proper Strategy.
Check strategy_flags_dict() for examples of the structure.
""" """
def __init__(self, strategy_type=None, strategy_config=None): def __init__(self, strategy_type=None, strategy_config=None):
_ = distribution_utils.configure_cluster( _ = distribution_utils.configure_cluster(
strategy_config.worker_hosts, strategy_config.task_index) strategy_config.worker_hosts, strategy_config.task_index)
"""Constructor.
Args:
strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'.
If None, the user is responsible for setting the strategy before calling
build_executor(...).
strategy_config: necessary config for constructing the proper Strategy.
Check strategy_flags_dict() for examples of the structure.
"""
self._strategy = distribution_utils.get_distribution_strategy( self._strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=strategy_type, distribution_strategy=strategy_type,
num_gpus=strategy_config.num_gpus, num_gpus=strategy_config.num_gpus,
...@@ -755,7 +772,6 @@ class ExecutorBuilder(object): ...@@ -755,7 +772,6 @@ class ExecutorBuilder(object):
"""Sets default summary writer for the current thread.""" """Sets default summary writer for the current thread."""
self._strategy = new_strategy self._strategy = new_strategy
def build_executor(self, def build_executor(self,
class_ctor=DistributedExecutor, class_ctor=DistributedExecutor,
params=None, params=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment