"vscode:/vscode.git/clone" did not exist on "df399534011a41f5ce371741ba3102f60d3f6c41"
Commit 44e7092c authored by stephenwu's avatar stephenwu
Browse files

Merge branch 'master' of https://github.com/tensorflow/models into AXg

parents 431a9ca3 59434199
......@@ -42,6 +42,7 @@ This repository provides a curated list of the GitHub repositories with machine
| [BERT](https://github.com/IntelAI/models/tree/master/benchmarks/language_modeling/tensorflow/bert_large) | [BERT: Pre-training of Deep Bidirectional Transformers<br/>for Language Understanding](https://arxiv.org/pdf/1810.04805) | • FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
| [GNMT](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/mlperf_gnmt) | [Google’s Neural Machine Translation System:<br/>Bridging the Gap between Human and Machine Translation](https://arxiv.org/pdf/1609.08144) | • FP32 Inference | [Intel](https://github.com/IntelAI) |
| [Transformer-LT](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/transformer_mlperf) | [Attention Is All You Need](https://arxiv.org/pdf/1706.03762) | • FP32 Training | [Intel](https://github.com/IntelAI) |
| [ELECTRA](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/LanguageModeling/ELECTRA) | [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://openreview.net/forum?id=r1xMH1BtvB) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• Multi-node training on a Pyxis/Enroot Slurm cluster | [NVIDIA](https://github.com/NVIDIA) |
## Recommendation Systems
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the base task abstraction."""
import abc
from typing import Optional
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Standard Trainer implementation.
The base trainer implements the Orbit `StandardTrainable` and
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow_models.core.trainers.trainer."""
# pylint: disable=g-direct-tensorflow-import
import os
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
from typing import Optional, Sequence, Union
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experiment factory methods."""
from official.core import config_definitions as cfg
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A common dataset reader."""
import random
from typing import Any, Callable, Optional
......@@ -31,6 +30,10 @@ def _get_random_integer():
class InputReader:
"""Input reader that returns a tf.data.Dataset instance."""
# A static random number which is the same across different InputReader
# instances.
static_randnum = _get_random_integer()
def __init__(self,
params: cfg.DataConfig,
dataset_fn=tf.data.TFRecordDataset,
......@@ -137,7 +140,13 @@ class InputReader:
self._enable_tf_data_service = (
params.enable_tf_data_service and params.tf_data_service_address)
self._tf_data_service_address = params.tf_data_service_address
self._tf_data_service_job_name = params.tf_data_service_job_name
if self._enable_tf_data_service:
# Add a random seed as the tf.data service job name suffix, so tf.data
# service doesn't reuse the previous state if TPU worker gets preempted.
self._tf_data_service_job_name = (
params.tf_data_service_job_name + str(self.static_randnum))
self._enable_round_robin_tf_data_service = params.get(
'enable_round_robin_tf_data_service', False)
def _shard_files_then_read(
self, input_context: Optional[tf.distribute.InputContext] = None):
......@@ -165,7 +174,8 @@ class InputReader:
map_func=self._dataset_fn,
cycle_length=self._cycle_length,
block_length=self._block_length,
num_parallel_calls=tf.data.experimental.AUTOTUNE,
num_parallel_calls=(self._cycle_length if self._cycle_length else
tf.data.experimental.AUTOTUNE),
deterministic=self._deterministic)
return dataset
......@@ -277,7 +287,30 @@ class InputReader:
dataset = maybe_map_fn(dataset, self._postprocess_fn)
if self._enable_tf_data_service:
if self._enable_tf_data_service and input_context:
if self._enable_round_robin_tf_data_service:
replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
input_context.num_input_pipelines)
base_consumer_index = input_context.input_pipeline_id * (
replicas_per_input_pipeline)
num_consumers = input_context.num_input_pipelines * (
replicas_per_input_pipeline)
range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
dataset = range_dataset.map(lambda i: dataset.apply( # pylint: disable=g-long-lambda
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
service=self._tf_data_service_address,
job_name=self._tf_data_service_job_name,
consumer_index=base_consumer_index + i,
num_consumers=num_consumers)))
# Use parallel interleave to read multiple batches from a tf.data
# service worker in parallel.
dataset = dataset.interleave(
lambda x: x,
cycle_length=replicas_per_input_pipeline,
num_parallel_calls=replicas_per_input_pipeline,
deterministic=True)
else:
dataset = dataset.apply(
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry utility."""
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for registry."""
from __future__ import absolute_import
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to register and access all registered tasks."""
from official.core import registry
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM common training driver library."""
# pytype: disable=attribute-error
import os
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for train_ctl_lib."""
import json
import os
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training utils."""
import copy
import json
......@@ -109,7 +108,7 @@ class BestCheckpointExporter:
# Saving the best checkpoint might be interrupted if the job got killed.
for file_to_remove in tf.io.gfile.glob(self.best_ckpt_path + '*'):
tf.io.gfile.remove(file_to_remove)
checkpoint.save(self.best_ckpt_path)
checkpoint.write(self.best_ckpt_path)
@property
def best_ckpt_logs(self):
......@@ -218,6 +217,16 @@ def serialize_config(params: config_definitions.ExperimentConfig,
hyperparams.save_params_dict_to_yaml(params, params_save_path)
def save_gin_config(filename_surfix: str, model_dir: str):
  """Serializes and saves the operative gin configuration.

  The saved file is named 'operative_config.<filename_surfix>.gin' inside
  `model_dir`. (The previous docstring said "experiment config", which was a
  copy-paste error from `serialize_config` — this function saves the *gin*
  config.)

  Note: the parameter keeps the historical "surfix" (sic) spelling so
  keyword-argument callers are not broken.

  Args:
    filename_surfix: Suffix inserted into the saved file name.
    model_dir: Directory to write the gin config file to; created if it does
      not already exist.
  """
  gin_save_path = os.path.join(
      model_dir, 'operative_config.{}.gin'.format(filename_surfix))
  logging.info('Saving gin configurations to %s', gin_save_path)
  # Make sure the destination directory exists before opening the file.
  tf.io.gfile.makedirs(model_dir)
  with tf.io.gfile.GFile(gin_save_path, 'w') as f:
    f.write(gin.operative_config_str())
def read_global_step_from_checkpoint(ckpt_file_path):
"""Read global step from checkpoint, or get global step from its filename."""
global_step = tf.Variable(-1, dtype=tf.int64)
......
......@@ -30,7 +30,7 @@ from official.modeling.hyperparams import params_dict
class Config(params_dict.ParamsDict):
"""The base configuration class that supports YAML/JSON based overrides.
* It recursively enforces a whitelist of basic types and container types, so
* It recursively enforces an allowlist of basic types and container types, so
it avoids surprises with copy and reuse caused by unanticipated types.
* It converts dict to Config even within sequences,
e.g. for config = Config({'key': [([{'a': 42}],)]),
......
......@@ -312,7 +312,7 @@ class ParamsDict(object):
def read_yaml_to_params_dict(file_path):
  """Reads a YAML file into a ParamsDict.

  The scraped block carried both the pre- and post-change assignment lines
  (a diff artifact); only the `SafeLoader` version is kept.

  Args:
    file_path: Path to the YAML file. Opened via tf.io.gfile, so remote
      filesystem paths work as well as local ones.

  Returns:
    A ParamsDict constructed from the parsed YAML content.
  """
  with tf.io.gfile.GFile(file_path, 'r') as f:
    # SafeLoader will not instantiate arbitrary Python objects, unlike
    # FullLoader, so loading an untrusted config file cannot execute code.
    params_dict = yaml.load(f, Loader=yaml.SafeLoader)
  return ParamsDict(params_dict)
......
......@@ -18,27 +18,26 @@ from typing import Optional, Tuple
import dataclasses
from official.core import config_definitions as cfg
from official.modeling.hyperparams import base_config
from official.modeling import hyperparams
@dataclasses.dataclass
class TaskRoutine(base_config.Config):
class TaskRoutine(hyperparams.Config):
task_name: str = ""
task_config: cfg.TaskConfig = None
mixing_steps: int = 1
eval_steps: Optional[int] = None
task_weight: Optional[float] = None
task_weight: Optional[float] = 1.0
@dataclasses.dataclass
class MultiTaskConfig(base_config.Config):
class MultiTaskConfig(hyperparams.Config):
init_checkpoint: str = ""
model: base_config.Config = None
model: hyperparams.Config = None
task_routines: Tuple[TaskRoutine, ...] = ()
@dataclasses.dataclass
class MultiEvalExperimentConfig(base_config.Config):
class MultiEvalExperimentConfig(hyperparams.Config):
"""An experiment config for single-task training and multi-task evaluation.
Attributes:
......
......@@ -32,16 +32,16 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
def __init__(self,
tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
task_mixing_steps: Optional[Dict[str, int]] = None,
task_weights: Optional[Dict[str, float]] = None,
task_weights: Optional[Dict[str, Union[float, int]]] = None,
task_eval_steps: Optional[Dict[str, int]] = None,
name: Optional[str] = None):
"""MultiTask initialization.
Args:
tasks: a list or a flat dict of Task.
task_mixing_steps: a dict of (task, mixing steps).
task_weights: a dict of (task, loss weight).
task_weights: a dict of (task, task weight); a task weight can either be
applied directly during loss summation in a joint backward step, or be
used to sample tasks among interleaved backward steps.
task_eval_steps: a dict of (task, eval steps).
name: the instance name of a MultiTask object.
"""
......@@ -62,31 +62,24 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
self._task_eval_steps = dict([
(name, self._task_eval_steps.get(name, None)) for name in self.tasks
])
self._task_mixing_steps = task_mixing_steps or {}
self._task_mixing_steps = dict([
(name, self._task_mixing_steps.get(name, 1)) for name in self.tasks
])
self._task_weights = task_weights or {}
self._task_weights = dict([
(name, self._task_weights.get(name, None)) for name in self.tasks
(name, self._task_weights.get(name, 1.0)) for name in self.tasks
])
@classmethod
def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):
tasks = {}
task_eval_steps = {}
task_mixing_steps = {}
task_weights = {}
for task_routine in config.task_routines:
task_name = task_routine.task_name
tasks[task_name] = task_factory.get_task(
task_routine.task_config, logging_dir=logging_dir)
task_eval_steps[task_name] = task_routine.eval_steps
task_mixing_steps[task_name] = task_routine.mixing_steps
task_weights[task_name] = task_routine.task_weight
return cls(
tasks,
task_mixing_steps=task_mixing_steps,
task_eval_steps=task_eval_steps,
task_weights=task_weights)
......@@ -97,12 +90,13 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
def task_eval_steps(self, task_name):
  """Returns the configured eval steps for `task_name`; may be None.

  The backing dict is built with a default of None for tasks that did not
  specify eval steps, so callers must handle a None return.
  """
  return self._task_eval_steps[task_name]
def task_mixing_steps(self, task_name):
  """Returns the mixing steps for `task_name` (backing dict defaults to 1)."""
  return self._task_mixing_steps[task_name]
def task_weight(self, task_name):
  """Returns the weight for `task_name`.

  Per the constructor docs, the weight is either applied during loss
  summation in a joint backward step or used to sample tasks among
  interleaved backward steps.
  """
  return self._task_weights[task_name]
@property
def task_weights(self):
  # Mapping of task name -> task weight, with one entry per registered task.
  return self._task_weights
@classmethod
def create_optimizer(cls,
optimizer_config: OptimizationConfig,
......
......@@ -25,12 +25,16 @@ from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import multitask
def run_experiment_wtih_multitask_eval(
def run_experiment_with_multitask_eval(
*,
distribution_strategy: tf.distribute.Strategy, train_task: base_task.Task,
eval_tasks: multitask.MultiTask, mode: str,
distribution_strategy: tf.distribute.Strategy,
train_task: base_task.Task,
eval_tasks: multitask.MultiTask,
mode: str,
params: configs.MultiEvalExperimentConfig,
model_dir: str) -> tf.keras.Model:
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True) -> tf.keras.Model:
"""Runs train/eval configured by the experiment params.
Args:
......@@ -41,6 +45,9 @@ def run_experiment_wtih_multitask_eval(
or 'continuous_eval'.
params: MultiEvalExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run evaluation once after training completes; if
True, the evaluation metrics logs are returned along with the model.
save_summary: Whether to save train and validation summary.
Returns:
model: `tf.keras.Model` instance.
......@@ -92,9 +99,11 @@ def run_experiment_wtih_multitask_eval(
global_step=global_step,
steps_per_loop=params.trainer.steps_per_loop,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(model_dir, 'train'),
eval_summary_dir=os.path.join(model_dir, 'validation'),
summary_interval=params.trainer.summary_interval)
summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
eval_summary_dir=os.path.join(model_dir, 'validation') if
(save_summary) else None,
summary_interval=params.trainer.summary_interval if
(save_summary) else None)
logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope():
......@@ -121,4 +130,8 @@ def run_experiment_wtih_multitask_eval(
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
return model
if run_post_eval:
return model, evaluator.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps))
else:
return model, {}
......@@ -40,7 +40,7 @@ def configure_optimizer(optimizer,
optimizer, dynamic=False, initial_scale=loss_scale)
if use_graph_rewrite:
# Note: the model dtype must be 'float32', which will ensure
# tf.ckeras.mixed_precision and
# tf.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment