Commit 44e7092c authored by stephenwu
Browse files

Merge branch 'master' of https://github.com/tensorflow/models into AXg

parents 431a9ca3 59434199
...@@ -42,6 +42,7 @@ This repository provides a curated list of the GitHub repositories with machine ...@@ -42,6 +42,7 @@ This repository provides a curated list of the GitHub repositories with machine
| [BERT](https://github.com/IntelAI/models/tree/master/benchmarks/language_modeling/tensorflow/bert_large) | [BERT: Pre-training of Deep Bidirectional Transformers<br/>for Language Understanding](https://arxiv.org/pdf/1810.04805) | • FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) | | [BERT](https://github.com/IntelAI/models/tree/master/benchmarks/language_modeling/tensorflow/bert_large) | [BERT: Pre-training of Deep Bidirectional Transformers<br/>for Language Understanding](https://arxiv.org/pdf/1810.04805) | • FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
| [GNMT](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/mlperf_gnmt) | [Google’s Neural Machine Translation System:<br/>Bridging the Gap between Human and Machine Translation](https://arxiv.org/pdf/1609.08144) | • FP32 Inference | [Intel](https://github.com/IntelAI) | | [GNMT](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/mlperf_gnmt) | [Google’s Neural Machine Translation System:<br/>Bridging the Gap between Human and Machine Translation](https://arxiv.org/pdf/1609.08144) | • FP32 Inference | [Intel](https://github.com/IntelAI) |
| [Transformer-LT](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/transformer_mlperf) | [Attention Is All You Need](https://arxiv.org/pdf/1706.03762) | • FP32 Training | [Intel](https://github.com/IntelAI) | | [Transformer-LT](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/transformer_mlperf) | [Attention Is All You Need](https://arxiv.org/pdf/1706.03762) | • FP32 Training | [Intel](https://github.com/IntelAI) |
| [ELECTRA](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/LanguageModeling/ELECTRA) | [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://openreview.net/forum?id=r1xMH1BtvB) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• Multi-node training on a Pyxis/Enroot Slurm cluster | [NVIDIA](https://github.com/NVIDIA) |
## Recommendation Systems ## Recommendation Systems
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,7 @@ ...@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Defines the base task abstraction.""" """Defines the base task abstraction."""
import abc import abc
from typing import Optional from typing import Optional
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,7 @@ ...@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Standard Trainer implementation. """Standard Trainer implementation.
The base trainer implements the Orbit `StandardTrainable` and The base trainer implements the Orbit `StandardTrainable` and
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,7 @@ ...@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Tests for tensorflow_models.core.trainers.trainer.""" """Tests for tensorflow_models.core.trainers.trainer."""
# pylint: disable=g-direct-tensorflow-import # pylint: disable=g-direct-tensorflow-import
import os import os
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,7 @@ ...@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Common configuration settings.""" """Common configuration settings."""
from typing import Optional, Sequence, Union from typing import Optional, Sequence, Union
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,7 @@ ...@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Experiment factory methods.""" """Experiment factory methods."""
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,7 @@ ...@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""A common dataset reader.""" """A common dataset reader."""
import random import random
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
...@@ -31,6 +30,10 @@ def _get_random_integer(): ...@@ -31,6 +30,10 @@ def _get_random_integer():
class InputReader: class InputReader:
"""Input reader that returns a tf.data.Dataset instance.""" """Input reader that returns a tf.data.Dataset instance."""
# A static random number which is the same across different InputReader
# instances.
static_randnum = _get_random_integer()
def __init__(self, def __init__(self,
params: cfg.DataConfig, params: cfg.DataConfig,
dataset_fn=tf.data.TFRecordDataset, dataset_fn=tf.data.TFRecordDataset,
...@@ -137,7 +140,13 @@ class InputReader: ...@@ -137,7 +140,13 @@ class InputReader:
self._enable_tf_data_service = ( self._enable_tf_data_service = (
params.enable_tf_data_service and params.tf_data_service_address) params.enable_tf_data_service and params.tf_data_service_address)
self._tf_data_service_address = params.tf_data_service_address self._tf_data_service_address = params.tf_data_service_address
self._tf_data_service_job_name = params.tf_data_service_job_name if self._enable_tf_data_service:
# Add a random seed as the tf.data service job name suffix, so tf.data
# service doesn't reuse the previous state if TPU worker gets preempted.
self._tf_data_service_job_name = (
params.tf_data_service_job_name + str(self.static_randnum))
self._enable_round_robin_tf_data_service = params.get(
'enable_round_robin_tf_data_service', False)
def _shard_files_then_read( def _shard_files_then_read(
self, input_context: Optional[tf.distribute.InputContext] = None): self, input_context: Optional[tf.distribute.InputContext] = None):
...@@ -165,7 +174,8 @@ class InputReader: ...@@ -165,7 +174,8 @@ class InputReader:
map_func=self._dataset_fn, map_func=self._dataset_fn,
cycle_length=self._cycle_length, cycle_length=self._cycle_length,
block_length=self._block_length, block_length=self._block_length,
num_parallel_calls=tf.data.experimental.AUTOTUNE, num_parallel_calls=(self._cycle_length if self._cycle_length else
tf.data.experimental.AUTOTUNE),
deterministic=self._deterministic) deterministic=self._deterministic)
return dataset return dataset
...@@ -277,12 +287,35 @@ class InputReader: ...@@ -277,12 +287,35 @@ class InputReader:
dataset = maybe_map_fn(dataset, self._postprocess_fn) dataset = maybe_map_fn(dataset, self._postprocess_fn)
if self._enable_tf_data_service: if self._enable_tf_data_service and input_context:
dataset = dataset.apply( if self._enable_round_robin_tf_data_service:
tf.data.experimental.service.distribute( replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
processing_mode='parallel_epochs', input_context.num_input_pipelines)
service=self._tf_data_service_address, base_consumer_index = input_context.input_pipeline_id * (
job_name=self._tf_data_service_job_name)) replicas_per_input_pipeline)
num_consumers = input_context.num_input_pipelines * (
replicas_per_input_pipeline)
range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
dataset = range_dataset.map(lambda i: dataset.apply( # pylint: disable=g-long-lambda
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
service=self._tf_data_service_address,
job_name=self._tf_data_service_job_name,
consumer_index=base_consumer_index + i,
num_consumers=num_consumers)))
# Use parallel interleave to read multiple batches from a tf.data
# service worker in parallel.
dataset = dataset.interleave(
lambda x: x,
cycle_length=replicas_per_input_pipeline,
num_parallel_calls=replicas_per_input_pipeline,
deterministic=True)
else:
dataset = dataset.apply(
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
service=self._tf_data_service_address,
job_name=self._tf_data_service_job_name))
if self._deterministic is not None: if self._deterministic is not None:
options = tf.data.Options() options = tf.data.Options()
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,7 @@ ...@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Registry utility.""" """Registry utility."""
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,7 @@ ...@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Tests for registry.""" """Tests for registry."""
from __future__ import absolute_import from __future__ import absolute_import
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,7 @@ ...@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""A global factory to register and access all registered tasks.""" """A global factory to register and access all registered tasks."""
from official.core import registry from official.core import registry
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,7 @@ ...@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""TFM common training driver library.""" """TFM common training driver library."""
# pytype: disable=attribute-error # pytype: disable=attribute-error
import os import os
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,7 @@ ...@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Tests for train_ctl_lib.""" """Tests for train_ctl_lib."""
import json import json
import os import os
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,7 @@ ...@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Training utils.""" """Training utils."""
import copy import copy
import json import json
...@@ -109,7 +108,7 @@ class BestCheckpointExporter: ...@@ -109,7 +108,7 @@ class BestCheckpointExporter:
# Saving the best checkpoint might be interrupted if the job got killed. # Saving the best checkpoint might be interrupted if the job got killed.
for file_to_remove in tf.io.gfile.glob(self.best_ckpt_path + '*'): for file_to_remove in tf.io.gfile.glob(self.best_ckpt_path + '*'):
tf.io.gfile.remove(file_to_remove) tf.io.gfile.remove(file_to_remove)
checkpoint.save(self.best_ckpt_path) checkpoint.write(self.best_ckpt_path)
@property @property
def best_ckpt_logs(self): def best_ckpt_logs(self):
...@@ -218,6 +217,16 @@ def serialize_config(params: config_definitions.ExperimentConfig, ...@@ -218,6 +217,16 @@ def serialize_config(params: config_definitions.ExperimentConfig,
hyperparams.save_params_dict_to_yaml(params, params_save_path) hyperparams.save_params_dict_to_yaml(params, params_save_path)
def save_gin_config(filename_surfix: str, model_dir: str):
"""Serializes and saves the experiment config."""
gin_save_path = os.path.join(
model_dir, 'operative_config.{}.gin'.format(filename_surfix))
logging.info('Saving gin configurations to %s', gin_save_path)
tf.io.gfile.makedirs(model_dir)
with tf.io.gfile.GFile(gin_save_path, 'w') as f:
f.write(gin.operative_config_str())
def read_global_step_from_checkpoint(ckpt_file_path): def read_global_step_from_checkpoint(ckpt_file_path):
"""Read global step from checkpoint, or get global step from its filename.""" """Read global step from checkpoint, or get global step from its filename."""
global_step = tf.Variable(-1, dtype=tf.int64) global_step = tf.Variable(-1, dtype=tf.int64)
......
...@@ -30,7 +30,7 @@ from official.modeling.hyperparams import params_dict ...@@ -30,7 +30,7 @@ from official.modeling.hyperparams import params_dict
class Config(params_dict.ParamsDict): class Config(params_dict.ParamsDict):
"""The base configuration class that supports YAML/JSON based overrides. """The base configuration class that supports YAML/JSON based overrides.
* It recursively enforces a whitelist of basic types and container types, so * It recursively enforces an allowlist of basic types and container types, so
it avoids surprises with copy and reuse caused by unanticipated types. it avoids surprises with copy and reuse caused by unanticipated types.
* It converts dict to Config even within sequences, * It converts dict to Config even within sequences,
e.g. for config = Config({'key': [([{'a': 42}],)]), e.g. for config = Config({'key': [([{'a': 42}],)]),
......
...@@ -312,7 +312,7 @@ class ParamsDict(object): ...@@ -312,7 +312,7 @@ class ParamsDict(object):
def read_yaml_to_params_dict(file_path): def read_yaml_to_params_dict(file_path):
"""Reads a YAML file to a ParamsDict.""" """Reads a YAML file to a ParamsDict."""
with tf.io.gfile.GFile(file_path, 'r') as f: with tf.io.gfile.GFile(file_path, 'r') as f:
params_dict = yaml.load(f, Loader=yaml.FullLoader) params_dict = yaml.load(f, Loader=yaml.SafeLoader)
return ParamsDict(params_dict) return ParamsDict(params_dict)
......
...@@ -18,27 +18,26 @@ from typing import Optional, Tuple ...@@ -18,27 +18,26 @@ from typing import Optional, Tuple
import dataclasses import dataclasses
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.modeling.hyperparams import base_config from official.modeling import hyperparams
@dataclasses.dataclass @dataclasses.dataclass
class TaskRoutine(base_config.Config): class TaskRoutine(hyperparams.Config):
task_name: str = "" task_name: str = ""
task_config: cfg.TaskConfig = None task_config: cfg.TaskConfig = None
mixing_steps: int = 1
eval_steps: Optional[int] = None eval_steps: Optional[int] = None
task_weight: Optional[float] = None task_weight: Optional[float] = 1.0
@dataclasses.dataclass @dataclasses.dataclass
class MultiTaskConfig(base_config.Config): class MultiTaskConfig(hyperparams.Config):
init_checkpoint: str = "" init_checkpoint: str = ""
model: base_config.Config = None model: hyperparams.Config = None
task_routines: Tuple[TaskRoutine, ...] = () task_routines: Tuple[TaskRoutine, ...] = ()
@dataclasses.dataclass @dataclasses.dataclass
class MultiEvalExperimentConfig(base_config.Config): class MultiEvalExperimentConfig(hyperparams.Config):
"""An experiment config for single-task training and multi-task evaluation. """An experiment config for single-task training and multi-task evaluation.
Attributes: Attributes:
......
...@@ -32,16 +32,16 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta): ...@@ -32,16 +32,16 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
def __init__(self, def __init__(self,
tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]], tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
task_mixing_steps: Optional[Dict[str, int]] = None, task_weights: Optional[Dict[str, Union[float, int]]] = None,
task_weights: Optional[Dict[str, float]] = None,
task_eval_steps: Optional[Dict[str, int]] = None, task_eval_steps: Optional[Dict[str, int]] = None,
name: Optional[str] = None): name: Optional[str] = None):
"""MultiTask initialization. """MultiTask initialization.
Args: Args:
tasks: a list or a flat dict of Task. tasks: a list or a flat dict of Task.
task_mixing_steps: a dict of (task, mixing steps). task_weights: a dict of (task, task weight), task weight can be applied
task_weights: a dict of (task, loss weight). directly during loss summation in a joint backward step, or it can be
used to sample tasks among interleaved backward steps.
task_eval_steps: a dict of (task, eval steps). task_eval_steps: a dict of (task, eval steps).
name: the instance name of a MultiTask object. name: the instance name of a MultiTask object.
""" """
...@@ -62,31 +62,24 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta): ...@@ -62,31 +62,24 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
self._task_eval_steps = dict([ self._task_eval_steps = dict([
(name, self._task_eval_steps.get(name, None)) for name in self.tasks (name, self._task_eval_steps.get(name, None)) for name in self.tasks
]) ])
self._task_mixing_steps = task_mixing_steps or {}
self._task_mixing_steps = dict([
(name, self._task_mixing_steps.get(name, 1)) for name in self.tasks
])
self._task_weights = task_weights or {} self._task_weights = task_weights or {}
self._task_weights = dict([ self._task_weights = dict([
(name, self._task_weights.get(name, None)) for name in self.tasks (name, self._task_weights.get(name, 1.0)) for name in self.tasks
]) ])
@classmethod @classmethod
def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None): def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):
tasks = {} tasks = {}
task_eval_steps = {} task_eval_steps = {}
task_mixing_steps = {}
task_weights = {} task_weights = {}
for task_routine in config.task_routines: for task_routine in config.task_routines:
task_name = task_routine.task_name task_name = task_routine.task_name
tasks[task_name] = task_factory.get_task( tasks[task_name] = task_factory.get_task(
task_routine.task_config, logging_dir=logging_dir) task_routine.task_config, logging_dir=logging_dir)
task_eval_steps[task_name] = task_routine.eval_steps task_eval_steps[task_name] = task_routine.eval_steps
task_mixing_steps[task_name] = task_routine.mixing_steps
task_weights[task_name] = task_routine.task_weight task_weights[task_name] = task_routine.task_weight
return cls( return cls(
tasks, tasks,
task_mixing_steps=task_mixing_steps,
task_eval_steps=task_eval_steps, task_eval_steps=task_eval_steps,
task_weights=task_weights) task_weights=task_weights)
...@@ -97,12 +90,13 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta): ...@@ -97,12 +90,13 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
def task_eval_steps(self, task_name): def task_eval_steps(self, task_name):
return self._task_eval_steps[task_name] return self._task_eval_steps[task_name]
def task_mixing_steps(self, task_name):
return self._task_mixing_steps[task_name]
def task_weight(self, task_name): def task_weight(self, task_name):
return self._task_weights[task_name] return self._task_weights[task_name]
@property
def task_weights(self):
return self._task_weights
@classmethod @classmethod
def create_optimizer(cls, def create_optimizer(cls,
optimizer_config: OptimizationConfig, optimizer_config: OptimizationConfig,
......
...@@ -25,12 +25,16 @@ from official.modeling.multitask import evaluator as evaluator_lib ...@@ -25,12 +25,16 @@ from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import multitask from official.modeling.multitask import multitask
def run_experiment_wtih_multitask_eval( def run_experiment_with_multitask_eval(
*, *,
distribution_strategy: tf.distribute.Strategy, train_task: base_task.Task, distribution_strategy: tf.distribute.Strategy,
eval_tasks: multitask.MultiTask, mode: str, train_task: base_task.Task,
eval_tasks: multitask.MultiTask,
mode: str,
params: configs.MultiEvalExperimentConfig, params: configs.MultiEvalExperimentConfig,
model_dir: str) -> tf.keras.Model: model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True) -> tf.keras.Model:
"""Runs train/eval configured by the experiment params. """Runs train/eval configured by the experiment params.
Args: Args:
...@@ -41,6 +45,9 @@ def run_experiment_wtih_multitask_eval( ...@@ -41,6 +45,9 @@ def run_experiment_wtih_multitask_eval(
or 'continuous_eval'. or 'continuous_eval'.
params: MultiEvalExperimentConfig instance. params: MultiEvalExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries. model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
Returns: Returns:
model: `tf.keras.Model` instance. model: `tf.keras.Model` instance.
...@@ -92,9 +99,11 @@ def run_experiment_wtih_multitask_eval( ...@@ -92,9 +99,11 @@ def run_experiment_wtih_multitask_eval(
global_step=global_step, global_step=global_step,
steps_per_loop=params.trainer.steps_per_loop, steps_per_loop=params.trainer.steps_per_loop,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(model_dir, 'train'), summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
eval_summary_dir=os.path.join(model_dir, 'validation'), eval_summary_dir=os.path.join(model_dir, 'validation') if
summary_interval=params.trainer.summary_interval) (save_summary) else None,
summary_interval=params.trainer.summary_interval if
(save_summary) else None)
logging.info('Starts to execute mode: %s', mode) logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope(): with distribution_strategy.scope():
...@@ -121,4 +130,8 @@ def run_experiment_wtih_multitask_eval( ...@@ -121,4 +130,8 @@ def run_experiment_wtih_multitask_eval(
else: else:
raise NotImplementedError('The mode is not implemented: %s' % mode) raise NotImplementedError('The mode is not implemented: %s' % mode)
return model if run_post_eval:
return model, evaluator.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps))
else:
return model, {}
...@@ -40,7 +40,7 @@ def configure_optimizer(optimizer, ...@@ -40,7 +40,7 @@ def configure_optimizer(optimizer,
optimizer, dynamic=False, initial_scale=loss_scale) optimizer, dynamic=False, initial_scale=loss_scale)
if use_graph_rewrite: if use_graph_rewrite:
# Note: the model dtype must be 'float32', which will ensure # Note: the model dtype must be 'float32', which will ensure
# tf.ckeras.mixed_precision and # tf.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up. # up.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite( optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment