Commit cf80ed4e authored by anivegesana

Merge branch 'purdue-yolo' of https://github.com/tensorflow/models into detection_generator_pr_2

parents 394cefcc 461b3587
@@ -239,9 +239,10 @@ class TrainerConfig(base_config.Config):
@dataclasses.dataclass
class TaskConfig(base_config.Config):
  init_checkpoint: str = ""
-  model: base_config.Config = None
+  model: Optional[base_config.Config] = None
  train_data: DataConfig = DataConfig()
  validation_data: DataConfig = DataConfig()
+  name: Optional[str] = None

@dataclasses.dataclass
...
@@ -142,14 +142,19 @@ class BestCheckpointExporter:
    return self._checkpoint_manager

-  def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
+  def maybe_export_checkpoint(
+      self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
+    """Compare eval_logs with past eval_logs and export checkpoint if better."""
    logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
                 eval_logs, global_step)
    if self._best_ckpt_logs is None or self._new_metric_is_better(
        self._best_ckpt_logs, eval_logs):
      self._best_ckpt_logs = eval_logs
-      self._export_best_eval_metric(checkpoint, self._best_ckpt_logs,
-                                    global_step)
+      if write_logs:
+        self.export_best_eval_metric(self._best_ckpt_logs, global_step)
+      self._get_checkpoint_manager(checkpoint).save()
+      return True
+    return False

  def _maybe_load_best_eval_metric(self):
    if not tf.io.gfile.exists(self.best_ckpt_logs_path):
@@ -180,7 +185,7 @@ class BestCheckpointExporter:
      return True
    return False

-  def _export_best_eval_metric(self, checkpoint, eval_logs, global_step):
+  def export_best_eval_metric(self, eval_logs, global_step):
    """Export evaluation results of the best checkpoint into a json file."""
    eval_logs_ext = copy.copy(eval_logs)
    eval_logs_ext['best_ckpt_global_step'] = global_step
@@ -190,8 +195,6 @@ class BestCheckpointExporter:
    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
      writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
-    self._get_checkpoint_manager(checkpoint).save()

  @property
  def best_ckpt_logs(self):
    return self._best_ckpt_logs
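The updated `maybe_export_checkpoint` above now reports whether a new best checkpoint was saved and lets callers defer the JSON metric export via `write_logs`. A hedged caller-side sketch follows; `exporter`, `checkpoint`, and `eval_logs` are hypothetical placeholders, not objects defined in this diff.

```python
# Hedged sketch: `exporter`, `checkpoint`, and `eval_logs` stand in for a
# configured BestCheckpointExporter, a tf.train.Checkpoint, and a metrics dict.
is_new_best = exporter.maybe_export_checkpoint(
    checkpoint, eval_logs, global_step=1000, write_logs=False)
if is_new_best:
  # With write_logs=False the checkpoint is still saved, but the JSON export
  # happens only when the caller asks for it explicitly.
  exporter.export_best_eval_metric(eval_logs, global_step=1000)
```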
@@ -377,11 +380,15 @@ def remove_ckpts(model_dir):
    tf.io.gfile.remove(file_to_remove)

-def try_count_params(model: tf.keras.Model):
+def try_count_params(
+    model: Union[tf.Module, tf.keras.Model],
+    trainable_only: bool = False):
  """Count the number of parameters if model is possible.

  Args:
    model: Try to count the number of params in this model.
+    trainable_only: Whether to calculate trainable params only. This flag is
+      not used when the model has `count_params` attribute.

  Returns:
    The number of parameters or None.
@@ -395,7 +402,13 @@ def try_count_params(model: tf.keras.Model):
                  'because the model was not feed any input, e.g., the max '
                  'train step already reached before this run.')
      return None
-  return None
+  else:
+    total_params = 0
+    variables = model.trainable_variables if trainable_only else model.variables
+    for var in variables:
+      shape = tf.shape(var)
+      total_params += tf.math.reduce_prod(shape).numpy()
+    return total_params

def try_count_flops(model: Union[tf.Module, tf.keras.Model],
...
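As a sanity check on the fallback branch added above, the following self-contained sketch counts the variables of a toy Keras model in the same way; the toy model and helper are illustrative only and not part of this change.

```python
import tensorflow as tf

# Toy model for illustration; BatchNormalization adds non-trainable
# moving-average statistics, so the two counts differ.
toy = tf.keras.Sequential([
    tf.keras.layers.Dense(4, input_shape=(3,)),
    tf.keras.layers.BatchNormalization(),
])

def count_params(model, trainable_only=False):
  """Mirrors the fallback branch: sum element counts over the variables."""
  variables = model.trainable_variables if trainable_only else model.variables
  return int(sum(tf.math.reduce_prod(tf.shape(v)).numpy() for v in variables))

print(count_params(toy))                       # 32: all variables
print(count_params(toy, trainable_only=True))  # 24: excludes BN moving stats
```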
@@ -23,6 +23,7 @@ from official.modeling import hyperparams
@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
+  # TODO(hongkuny): deprecate the task_name once we migrated client code.
  task_name: str = ""
  task_config: cfg.TaskConfig = None
  eval_steps: Optional[int] = None
@@ -76,4 +77,4 @@ class MultiEvalExperimentConfig(cfg.ExperimentConfig):
  Attributes:
    eval_tasks: individual evaluation tasks.
  """
-  eval_tasks: MultiTaskConfig = MultiTaskConfig()
+  eval_tasks: Tuple[TaskRoutine, ...] = ()
@@ -16,14 +16,14 @@
The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union

import gin
import orbit
import tensorflow as tf

+from official.core import base_task
from official.core import train_utils
from official.modeling.multitask import base_model
-from official.modeling.multitask import multitask

@gin.configurable
@@ -32,37 +32,39 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
  def __init__(
      self,
-      task: multitask.MultiTask,
+      eval_tasks: List[base_task.Task],
      model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
      global_step: Optional[tf.Variable] = None,
+      eval_steps: Optional[Dict[str, int]] = None,
      checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
    """Initialize common trainer for TensorFlow models.

    Args:
-      task: A multitask.MultiTask instance.
+      eval_tasks: A list of tasks to evaluate.
      model: tf.keras.Model instance.
      global_step: the global step variable.
+      eval_steps: a dictionary of steps to run eval keyed by task names.
      checkpoint_exporter: an object that has the `maybe_export_checkpoint`
        interface.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
-    self._task = task
+    self._tasks = eval_tasks
    self._model = model
    self._global_step = global_step or orbit.utils.create_global_step()
    self._checkpoint_exporter = checkpoint_exporter
    self._checkpoint = tf.train.Checkpoint(
-        global_step=self.global_step,
-        model=self.model)
+        global_step=self.global_step, model=self.model)

    self._validation_losses = None
    self._validation_metrics = None

    # Builds per-task datasets.
    self.eval_datasets = {}
-    for name, task in self.task.tasks.items():
-      self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
+    self.eval_steps = eval_steps or {}
+    for task in self.tasks:
+      self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
          self.strategy, task.build_inputs, task.task_config.validation_data)

    # Builds per-task validation loops.
@@ -89,8 +91,7 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
      return orbit.utils.create_loop_fn(eval_step_fn)

    self.task_fns = {
-        name: get_function(name, task)
-        for name, task in self.task.tasks.items()
+        task.name: get_function(task.name, task) for task in self.tasks
    }

  @property
@@ -98,8 +99,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
    return self._strategy

  @property
-  def task(self):
-    return self._task
+  def tasks(self):
+    return self._tasks

  @property
  def model(self):
@@ -115,8 +116,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
    if self._validation_losses is None:
      # Builds the per-task metrics and losses.
      self._validation_losses = {}
-      for name in self.task.tasks:
-        self._validation_losses[name] = tf.keras.metrics.Mean(
+      for task in self.tasks:
+        self._validation_losses[task.name] = tf.keras.metrics.Mean(
            "validation_loss", dtype=tf.float32)
    return self._validation_losses
@@ -126,8 +127,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
    if self._validation_metrics is None:
      # Builds the per-task metrics and losses.
      self._validation_metrics = {}
-      for name, task in self.task.tasks.items():
-        self._validation_metrics[name] = task.build_metrics(training=False)
+      for task in self.tasks:
+        self._validation_metrics[task.name] = task.build_metrics(training=False)
    return self._validation_metrics

  @property
@@ -145,12 +146,12 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
    results = {}
    eval_iters = tf.nest.map_structure(iter, self.eval_datasets)

-    for name, task_eval_loop in self.task_fns.items():
+    for task in self.tasks:
      outputs = None
+      name = task.name
      eval_iter = eval_iters[name]
-      task = self.task.tasks[name]
-      task_eval_steps = self.task.task_eval_steps(name) or num_steps
-      outputs = task_eval_loop(
+      task_eval_steps = self.eval_steps.get(name, None) or num_steps
+      outputs = self.task_fns[name](
          eval_iter,
          task_eval_steps,
          state=outputs,
...
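For orientation, here is a hedged construction sketch of the evaluator's new interface; `my_tasks` and `my_model` are hypothetical stand-ins for real `base_task.Task` instances (each exposing a `.name`) and a compatible model, and the step counts are arbitrary.

```python
import tensorflow as tf

# Hedged sketch only: `my_tasks` and `my_model` are placeholders.
evaluator = MultiTaskEvaluator(
    eval_tasks=my_tasks,              # list of base_task.Task
    model=my_model,
    eval_steps={'foo': 2, 'bar': 3},  # optional; missing names fall back to num_steps
)
results = evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
```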
@@ -22,7 +22,6 @@ from tensorflow.python.distribute import strategy_combinations
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling.multitask import evaluator
-from official.modeling.multitask import multitask

def all_strategy_combinations():
@@ -89,9 +88,7 @@ class MockTask(base_task.Task):
            np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
    return state

-  def reduce_aggregated_logs(self,
-                             aggregated_logs,
-                             global_step=None):
+  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    for k, v in aggregated_logs.items():
      aggregated_logs[k] = np.sum(np.stack(v, axis=0))
    return aggregated_logs
@@ -106,10 +103,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
        MockTask(params=cfg.TaskConfig(), name="bar"),
        MockTask(params=cfg.TaskConfig(), name="foo")
    ]
-    test_multitask = multitask.MultiTask(tasks=tasks)
    model = MockModel()
    test_evaluator = evaluator.MultiTaskEvaluator(
-        task=test_multitask, model=model)
+        eval_tasks=tasks, model=model)
    results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
    self.assertContainsSubset(["validation_loss", "acc"], results["bar"].keys())
    self.assertContainsSubset(["validation_loss", "acc"], results["foo"].keys())
@@ -123,10 +119,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
        MockTask(params=cfg.TaskConfig(), name="bar"),
        MockTask(params=cfg.TaskConfig(), name="foo")
    ]
-    test_multitask = multitask.MultiTask(tasks=tasks)
    model = MockModel()
    test_evaluator = evaluator.MultiTaskEvaluator(
-        task=test_multitask, model=model)
+        eval_tasks=tasks, model=model)
    results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertEqual(results["bar"]["counter"],
                     5. * distribution.num_replicas_in_sync)
...
@@ -34,7 +34,7 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
               optimizer: tf.optimizers.Optimizer,
               task_sampler: sampler.TaskSampler,
               trainer_options=None):
-    super(MultiTaskInterleavingTrainer, self).__init__(
+    super().__init__(
        multi_task=multi_task,
        multi_task_model=multi_task_model,
        optimizer=optimizer,
@@ -90,3 +90,13 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
          self._task_train_step_map[name], args=(next(iterator_map[name]),))
      self.global_step.assign_add(1)
      self.task_step_counter(name).assign_add(1)
+
+  def train_loop_end(self):
+    """Record loss and metric values per task."""
+    result = super().train_loop_end()
+    # Interleaving training does not have a good semantic for `total_loss`. In
+    # fact, it is always zero. To avoid confusion, we filter the `total_loss`
+    # from the result logs.
+    if 'total_loss' in result:
+      result.pop('total_loss')
+    return result
@@ -60,6 +60,7 @@ class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())
+    self.assertNotIn("total_loss", results)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_with_configs(self, distribution):
...
@@ -59,10 +59,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
    else:
      raise ValueError("The tasks argument has an invalid type: %s" %
                       type(tasks))
-    self._task_eval_steps = task_eval_steps or {}
-    self._task_eval_steps = dict([
-        (name, self._task_eval_steps.get(name, None)) for name in self.tasks
-    ])
+    self.task_eval_steps = task_eval_steps or {}
    self._task_weights = task_weights or {}
    self._task_weights = dict([
        (name, self._task_weights.get(name, 1.0)) for name in self.tasks
@@ -74,9 +71,9 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
    task_eval_steps = {}
    task_weights = {}
    for task_routine in config.task_routines:
-      task_name = task_routine.task_name
+      task_name = task_routine.task_name or task_routine.task_config.name
      tasks[task_name] = task_factory.get_task(
-          task_routine.task_config, logging_dir=logging_dir)
+          task_routine.task_config, logging_dir=logging_dir, name=task_name)
      task_eval_steps[task_name] = task_routine.eval_steps
      task_weights[task_name] = task_routine.task_weight
    return cls(
@@ -86,9 +83,6 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
  def tasks(self):
    return self._tasks

-  def task_eval_steps(self, task_name):
-    return self._task_eval_steps[task_name]
-
  def task_weight(self, task_name):
    return self._task_weights[task_name]
...
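Because `from_config` now falls back to `task_config.name` when `task_name` is empty (and `TaskConfig` gained a `name` field earlier in this commit), a routine can omit the deprecated `task_name`. The snippet below is a hypothetical illustration, not a config used in this repository.

```python
from official.core import config_definitions as cfg
from official.modeling.multitask import configs

# Hypothetical routine: task_name is left empty on purpose, so
# MultiTask.from_config resolves the task key from task_config.name instead.
routine = configs.TaskRoutine(
    task_config=cfg.TaskConfig(name='foo'),
    eval_steps=2)
```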
@@ -15,7 +15,7 @@
"""Multitask training driver library."""
# pytype: disable=attribute-error
import os
-from typing import Optional
+from typing import List, Optional
from absl import logging
import orbit
import tensorflow as tf
@@ -69,9 +69,11 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
    trainer = TRAINERS[params.trainer.trainer_type](
        **kwargs) if is_training else None
    if is_eval:
+      eval_steps = task.task_eval_steps
      evaluator = evaluator_lib.MultiTaskEvaluator(
-          task=task,
+          eval_tasks=task.tasks.values(),
          model=model,
+          eval_steps=eval_steps,
          global_step=trainer.global_step if is_training else None,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))
@@ -137,7 +139,7 @@ def run_experiment_with_multitask_eval(
    *,
    distribution_strategy: tf.distribute.Strategy,
    train_task: base_task.Task,
-    eval_tasks: multitask.MultiTask,
+    eval_tasks: List[base_task.Task],
    mode: str,
    params: configs.MultiEvalExperimentConfig,
    model_dir: str,
@@ -149,7 +151,7 @@ def run_experiment_with_multitask_eval(
  Args:
    distribution_strategy: A distribution distribution_strategy.
    train_task: A base_task.Task instance.
-    eval_tasks: A multitask.MultiTask with evaluation tasks.
+    eval_tasks: A list of evaluation tasks.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: MultiEvalExperimentConfig instance.
@@ -173,8 +175,8 @@ def run_experiment_with_multitask_eval(
          config=params,
          task=train_task,
          model=train_task.build_model(),
-          optimizer=train_task.create_optimizer(
-              params.trainer.optimizer_config, params.runtime),
+          optimizer=train_task.create_optimizer(params.trainer.optimizer_config,
+                                                params.runtime),
          train=True,
          evaluate=False)
    else:
@@ -182,10 +184,14 @@ def run_experiment_with_multitask_eval(
    model = trainer.model if trainer else train_task.build_model()

    if is_eval:
+      eval_steps = dict([(task_routine.task_config.name,
+                          task_routine.eval_steps)
+                         for task_routine in params.eval_tasks])
      evaluator = evaluator_lib.MultiTaskEvaluator(
-          task=eval_tasks,
+          eval_tasks=eval_tasks,
          model=model,
          global_step=trainer.global_step if is_training else None,
+          eval_steps=eval_steps,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))
    else:
...
@@ -65,8 +65,7 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
        task=configs.MultiTaskConfig(
            task_routines=(
                configs.TaskRoutine(
-                    task_name='foo',
-                    task_config=test_utils.FooConfig()),
+                    task_name='foo', task_config=test_utils.FooConfig()),
                configs.TaskRoutine(
                    task_name='bar', task_config=test_utils.BarConfig()))))
    experiment_config = params_dict.override_params_dict(
@@ -95,18 +94,20 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
    model_dir = self.get_temp_dir()
    experiment_config = configs.MultiEvalExperimentConfig(
        task=test_utils.FooConfig(),
-        eval_tasks=configs.MultiTaskConfig(
-            task_routines=(
-                configs.TaskRoutine(
-                    task_name='foo',
-                    task_config=test_utils.FooConfig()),
-                configs.TaskRoutine(
-                    task_name='bar', task_config=test_utils.BarConfig()))))
+        eval_tasks=(configs.TaskRoutine(
+            task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
+                    configs.TaskRoutine(
+                        task_name='bar',
+                        task_config=test_utils.BarConfig(),
+                        eval_steps=3)))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      train_task = task_factory.get_task(experiment_config.task)
-      eval_tasks = multitask.MultiTask.from_config(experiment_config.eval_tasks)
+      eval_tasks = [
+          task_factory.get_task(config.task_config, name=config.task_name)
+          for config in experiment_config.eval_tasks
+      ]
      train_lib.run_experiment_with_multitask_eval(
          distribution_strategy=distribution_strategy,
          train_task=train_task,
...
@@ -28,7 +28,6 @@ from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.modeling.multitask import configs
-from official.modeling.multitask import multitask
from official.modeling.multitask import train_lib as multitask_train_lib
@@ -167,7 +166,10 @@ def run_continuous_finetune(
    with distribution_strategy.scope():
      if isinstance(params, configs.MultiEvalExperimentConfig):
        task = task_factory.get_task(params_replaced.task)
-        eval_tasks = multitask.MultiTask.from_config(params_replaced.eval_tasks)
+        eval_tasks = [
+            task_factory.get_task(config.task_config, name=config.task_name)
+            for config in params.eval_tasks
+        ]
        (_,
         eval_metrics) = multitask_train_lib.run_experiment_with_multitask_eval(
             distribution_strategy=distribution_strategy,
...
@@ -89,8 +89,7 @@ def _get_ngrams_with_counter(segment, max_order):
  Args:
    segment: text segment from which n-grams will be extracted.
-    max_order: maximum length in tokens of the n-grams returned by this
-      methods.
+    max_order: maximum length in tokens of the n-grams returned by this methods.

  Returns:
    The Counter containing all n-grams upto max_order in segment
@@ -104,15 +103,17 @@ def _get_ngrams_with_counter(segment, max_order):
  return ngram_counts

-def compute_bleu(reference_corpus, translation_corpus, max_order=4,
+def compute_bleu(reference_corpus,
+                 translation_corpus,
+                 max_order=4,
                 use_bp=True):
  """Computes BLEU score of translated segments against one or more references.

  Args:
-    reference_corpus: list of references for each translation. Each
-      reference should be tokenized into a list of tokens.
-    translation_corpus: list of translations to score. Each translation
-      should be tokenized into a list of tokens.
+    reference_corpus: list of references for each translation. Each reference
+      should be tokenized into a list of tokens.
+    translation_corpus: list of translations to score. Each translation should
+      be tokenized into a list of tokens.
    max_order: Maximum n-gram order to use when computing BLEU score.
    use_bp: boolean, whether to apply brevity penalty.
@@ -134,15 +135,14 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4,
    ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
    translation_ngram_counts = _get_ngrams_with_counter(translations, max_order)

-    overlap = dict((ngram,
-                    min(count, translation_ngram_counts[ngram]))
+    overlap = dict((ngram, min(count, translation_ngram_counts[ngram]))
                   for ngram, count in ref_ngram_counts.items())

    for ngram in overlap:
      matches_by_order[len(ngram) - 1] += overlap[ngram]
    for ngram in translation_ngram_counts:
-      possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[
-          ngram]
+      possible_matches_by_order[len(ngram) -
+                                1] += translation_ngram_counts[ngram]

  precisions = [0] * max_order
  smooth = 1.0
@@ -151,8 +151,8 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4,
    if possible_matches_by_order[i] > 0:
      precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
      if matches_by_order[i] > 0:
-        precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[
-            i]
+        precisions[i] = float(
+            matches_by_order[i]) / possible_matches_by_order[i]
      else:
        smooth *= 2
        precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
@@ -165,7 +165,8 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4,
  if use_bp:
    ratio = translation_length / reference_length
-    bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
+    bp = 0. if ratio < 1e-6 else math.exp(1 -
+                                          1. / ratio) if ratio < 1.0 else 1.0
  bleu = geo_mean * bp

  return np.float32(bleu)
...
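The brevity-penalty change above guards against an empty translation: a ratio of zero previously hit a division by zero in `1. / ratio`. A small standalone illustration of the new expression, with made-up lengths:

```python
import math

def brevity_penalty(translation_length, reference_length):
  """Same expression as the updated branch in compute_bleu."""
  ratio = translation_length / reference_length
  return 0. if ratio < 1e-6 else math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0

print(brevity_penalty(0, 10))   # 0.0: empty translations no longer divide by zero
print(brevity_penalty(8, 10))   # ~0.78: short translations are penalized
print(brevity_penalty(12, 10))  # 1.0: no penalty once ratio >= 1
```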
@@ -22,3 +22,4 @@ from official.nlp.modeling import layers
from official.nlp.modeling import losses
from official.nlp.modeling import models
from official.nlp.modeling import networks
+from official.nlp.modeling import ops
@@ -39,6 +39,23 @@ class NoNorm(tf.keras.layers.Layer):
    return output

+@tf.keras.utils.register_keras_serializable(package='Text')
+class NoNormClipped(NoNorm):
+  """Quantization friendly implementation for the NoNorm.
+
+  The output of NoNorm layer is clipped to [-6.0, 6.0] to make it quantization
+  friendly.
+  """
+
+  def __init__(self, name=None):
+    super(NoNormClipped, self).__init__(name=name)
+
+  def call(self, feature):
+    output = feature * self.scale + self.bias
+    clipped_output = tf.clip_by_value(output, -6.0, 6.0)
+    return clipped_output
+
def _get_norm_layer(normalization_type='no_norm', name=None):
  """Get normlization layer.
@@ -52,6 +69,8 @@ def _get_norm_layer(normalization_type='no_norm', name=None):
  """
  if normalization_type == 'no_norm':
    layer = NoNorm(name=name)
+  elif normalization_type == 'no_norm_clipped':
+    layer = NoNormClipped(name=name)
  elif normalization_type == 'layer_norm':
    layer = tf.keras.layers.LayerNormalization(
        name=name,
...
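The docstring above frames the clamp as quantization friendliness: once activations are guaranteed to lie in [-6.0, 6.0], a uniform int8 quantizer can use a fixed scale instead of tracking activation ranges at runtime. A rough, illustrative calculation (not library code; the helper below is made up for the example):

```python
# Uniform int8 quantization over the fixed [-6, 6] range assumed by the clamp.
QMIN, QMAX = -128, 127
RMIN, RMAX = -6.0, 6.0
SCALE = (RMAX - RMIN) / (QMAX - QMIN)  # ~0.047 per int8 step

def fake_quant(x):
  q = max(QMIN, min(QMAX, round(x / SCALE)))
  return q * SCALE

print(fake_quant(5.3))   # ~5.32: representable inside the clamped range
print(fake_quant(9.0))   # ~5.98: anything outside the range saturates
```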
@@ -33,6 +33,22 @@ def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
  return fake_input

+class EdgeTPUNoNormTest(tf.test.TestCase):
+
+  def test_no_norm(self):
+    layer = mobile_bert_layers.NoNormClipped()
+    feature = tf.random.uniform(
+        [2, 3, 4], minval=-8, maxval=8, dtype=tf.float32)
+    output = layer(feature)
+    output_shape = output.shape.as_list()
+    expected_shape = [2, 3, 4]
+    self.assertListEqual(output_shape, expected_shape, msg=None)
+    output_min = tf.reduce_min(output)
+    output_max = tf.reduce_max(output)
+    self.assertGreaterEqual(6.0, output_max)
+    self.assertLessEqual(-6.0, output_min)
+
class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):

  def test_embedding_layer_with_token_type(self):
...
@@ -106,16 +106,19 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
  def call(self, inputs, *, training=None):
    training = self.do_power_iteration if training is None else training
-    u_update_op, v_update_op, w_update_op = self.update_weights(
-        training=training)
-    output = self.layer(inputs)
-    w_restore_op = self.restore_weights()
-
-    # Register update ops.
-    self.add_update(u_update_op)
-    self.add_update(v_update_op)
-    self.add_update(w_update_op)
-    self.add_update(w_restore_op)
+    if training:
+      u_update_op, v_update_op, w_update_op = self.update_weights(
+          training=training)
+      output = self.layer(inputs)
+      w_restore_op = self.restore_weights()
+
+      # Register update ops.
+      self.add_update(u_update_op)
+      self.add_update(v_update_op)
+      self.add_update(w_update_op)
+      self.add_update(w_restore_op)
+    else:
+      output = self.layer(inputs)

    return output
...
@@ -45,10 +45,11 @@ class BertClassifier(tf.keras.Model):
    dropout_rate: The dropout probability of the cls head.
    use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
      encoder.
+    head_name: Name of the classification head.
    cls_head: (Optional) The layer instance to use for the classifier head.
      It should take in the output from network and produce the final logits.
      If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
-      'use_encoder_pooler') will be ignored.
+      'use_encoder_pooler', 'head_name') will be ignored.
  """

  def __init__(self,
@@ -57,9 +58,11 @@ class BertClassifier(tf.keras.Model):
               initializer='glorot_uniform',
               dropout_rate=0.1,
               use_encoder_pooler=True,
+               head_name='sentence_prediction',
               cls_head=None,
               **kwargs):
    self.num_classes = num_classes
+    self.head_name = head_name
    self.initializer = initializer
    self.use_encoder_pooler = use_encoder_pooler
@@ -92,7 +95,7 @@ class BertClassifier(tf.keras.Model):
        num_classes=num_classes,
        initializer=initializer,
        dropout_rate=dropout_rate,
-        name='sentence_prediction')
+        name=head_name)
    predictions = classifier(cls_inputs)
@@ -137,6 +140,7 @@ class BertClassifier(tf.keras.Model):
    return {
        'network': self._network,
        'num_classes': self.num_classes,
+        'head_name': self.head_name,
        'initializer': self.initializer,
        'use_encoder_pooler': self.use_encoder_pooler,
        'cls_head': self._cls_head,
...
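A hedged construction sketch of the new `head_name` argument, patterned on the tests below; it assumes the `official.nlp.modeling` package from this repository is importable and that the encoder keyword is `network`, and the sizes are arbitrary.

```python
from official.nlp.modeling import models, networks

# Small encoder, patterned on the tests below; sizes are arbitrary.
encoder = networks.BertEncoder(vocab_size=100, num_layers=2)

# The classification head is now named via head_name (previously hard-coded to
# 'sentence_prediction'), and the name round-trips through get_config().
classifier = models.BertClassifier(
    network=encoder, num_classes=4, head_name='my_prediction_head')
assert classifier.get_config()['head_name'] == 'my_prediction_head'
```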
@@ -87,10 +87,8 @@ class BertClassifierTest(keras_parameterized.TestCase):
              inner_dim=0, num_classes=4)))
  def test_serialize_deserialize(self, cls_head):
    """Validate that the BERT trainer can be serialized and deserialized."""
-    # Build a transformer network to use within the BERT trainer. (Here, we use
-    # a short sequence_length for convenience.)
-    test_network = networks.BertEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+    # Build a transformer network to use within the BERT trainer.
+    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)

    # Create a BERT trainer with the created network. (Note that all the args
    # are different, so we can catch any serialization mismatches.)
...
@@ -67,10 +67,8 @@ class BertPretrainerTest(keras_parameterized.TestCase):
  def test_bert_trainer_tensor_call(self):
    """Validate that the Keras object can be invoked."""
-    # Build a transformer network to use within the BERT trainer. (Here, we use
-    # a short sequence_length for convenience.)
-    test_network = networks.BertEncoder(
-        vocab_size=100, num_layers=2, sequence_length=2)
+    # Build a transformer network to use within the BERT trainer.
+    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)

    # Create a BERT trainer with the created network.
    bert_trainer_model = bert_pretrainer.BertPretrainer(
@@ -213,10 +211,8 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
  def test_v2_serialize_deserialize(self):
    """Validate that the BERT trainer can be serialized and deserialized."""
-    # Build a transformer network to use within the BERT trainer. (Here, we use
-    # a short sequence_length for convenience.)
-    test_network = networks.BertEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+    # Build a transformer network to use within the BERT trainer.
+    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)

    # Create a BERT trainer with the created network. (Note that all the args
    # are different, so we can catch any serialization mismatches.)
...
@@ -93,10 +93,8 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
  def test_serialize_deserialize(self):
    """Validate that the BERT trainer can be serialized and deserialized."""
-    # Build a transformer network to use within the BERT trainer. (Here, we use
-    # a short sequence_length for convenience.)
-    test_network = networks.BertEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+    # Build a transformer network to use within the BERT trainer.
+    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)

    # Create a BERT trainer with the created network. (Note that all the args
    # are different, so we can catch any serialization mismatches.)
...