Commit ea646b04 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 322707701
parent 7fcd7cba
......@@ -17,11 +17,33 @@
import abc
from typing import Any, Dict, Optional, Text
import dataclasses
from orbit import runner
from orbit import utils
import tensorflow as tf
@dataclasses.dataclass(frozen=True)
class TrainerOverrides:
"""Advanced overrides for Orbit trainers.
Attributes:
use_tf_while_loop: A boolean indicates whether to wrap the train step with
a `tf.while_loop`.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
use_tpu_summary_optimization: A boolean indicates whether to enable the
performance optimization for summaries in TPUs. In TPUs, writing
summaries with outside compilation inside train step is slow. If True,
it creates two `tf.function` with two XLA programs: one with summaries
and one without, and run the program with summaries (slow one) only if
necessary.
"""
use_tf_while_loop: bool = True
use_tf_function: bool = True
use_tpu_summary_optimization: bool = False
class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractTrainer APIs."""
......@@ -139,6 +161,17 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
self._train_iter = None
@dataclasses.dataclass(frozen=True)
class EvaluatorOverrides:
"""Advanced overrides for Orbit evaluators.
Attributes:
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
"""
use_tf_function: bool = False
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractEvaluator APIs."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment