Commit b0ccdb11 authored by Shixin Luo

resolve conflict with master

parents e61588cd 1611a8c5
...
@@ -11,24 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ==============================================================================
-"""Initializes TPU system for TF 2.0."""
-
-import tensorflow as tf
-
-
-def tpu_initialize(tpu_address):
-  """Initializes TPU for TF 2.0 training.
-
-  Args:
-    tpu_address: string, bns address of master TPU worker.
-
-  Returns:
-    A TPUClusterResolver.
-  """
-  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
-      tpu=tpu_address)
-  if tpu_address not in ('', 'local'):
-    tf.config.experimental_connect_to_cluster(cluster_resolver)
-  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
-  return cluster_resolver
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# throughout. Please refer to the TensorFlow dockerfiles documentation
+# for more information.
+
+# A list of assignees
+assignees:
+  - saikumarchalla
+  - ravikyram
...
@@ -35,6 +35,20 @@ This repository provides a curated list of the GitHub repositories with machine
 | [Mask R-CNN](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/MaskRCNN) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) |
 | [U-Net Medical Image Segmentation](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/UNet_Medical) | [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) |
+
+## Natural Language Processing
+| Model | Paper | Features | Maintainer |
+|-------|-------|----------|------------|
+| [BERT](https://github.com/IntelAI/models/tree/master/benchmarks/language_modeling/tensorflow/bert_large) | [BERT: Pre-training of Deep Bidirectional Transformers<br/>for Language Understanding](https://arxiv.org/pdf/1810.04805) | • FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
+| [GNMT](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/mlperf_gnmt) | [Google’s Neural Machine Translation System:<br/>Bridging the Gap between Human and Machine Translation](https://arxiv.org/pdf/1609.08144) | • FP32 Inference | [Intel](https://github.com/IntelAI) |
+| [Transformer-LT](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/transformer_mlperf) | [Attention Is All You Need](https://arxiv.org/pdf/1706.03762) | • FP32 Training | [Intel](https://github.com/IntelAI) |
+
+## Recommendation Systems
+| Model | Paper | Features | Maintainer |
+|-------|-------|----------|------------|
+| [Wide & Deep](https://github.com/IntelAI/models/tree/master/benchmarks/recommendation/tensorflow/wide_deep_large_ds) | [Wide & Deep Learning for Recommender Systems](https://arxiv.org/pdf/1606.07792) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
+
 ## Contributions
 If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for running models in a distributed setting."""

import json
import os
import random
import string

from absl import logging
import tensorflow as tf


def _collective_communication(all_reduce_alg):
  """Return a CollectiveCommunication based on all_reduce_alg.

  Args:
    all_reduce_alg: a string specifying which collective communication to pick,
      or None.

  Returns:
    tf.distribute.experimental.CollectiveCommunication object

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
  """
  collective_communication_options = {
      None: tf.distribute.experimental.CollectiveCommunication.AUTO,
      "ring": tf.distribute.experimental.CollectiveCommunication.RING,
      "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
  }
  if all_reduce_alg not in collective_communication_options:
    raise ValueError(
        "When used with `multi_worker_mirrored`, valid values for "
        "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
            all_reduce_alg))
  return collective_communication_options[all_reduce_alg]


def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
  """Return a CrossDeviceOps based on all_reduce_alg and num_packs.

  Args:
    all_reduce_alg: a string specifying which cross device op to pick, or None.
    num_packs: an integer specifying number of packs for the cross device op.

  Returns:
    tf.distribute.CrossDeviceOps object or None.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
  """
  if all_reduce_alg is None:
    return None
  mirrored_all_reduce_options = {
      "nccl": tf.distribute.NcclAllReduce,
      "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
  }
  if all_reduce_alg not in mirrored_all_reduce_options:
    raise ValueError(
        "When used with `mirrored`, valid values for all_reduce_alg are "
        "[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
            all_reduce_alg))
  cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
  return cross_device_ops_class(num_packs=num_packs)


def tpu_initialize(tpu_address):
  """Initializes TPU for TF 2.x training.

  Args:
    tpu_address: string, bns address of master TPU worker.

  Returns:
    A TPUClusterResolver.
  """
  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu=tpu_address)
  if tpu_address not in ("", "local"):
    tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  return cluster_resolver


def get_distribution_strategy(distribution_strategy="mirrored",
                              num_gpus=0,
                              all_reduce_alg=None,
                              num_packs=1,
                              tpu_address=None,
                              **kwargs):
  """Return a DistributionStrategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are "off", "one_device", "mirrored",
      "parameter_server", "multi_worker_mirrored", and "tpu" -- case
      insensitive. "off" means not to use Distribution Strategy; "tpu" means to
      use TPUStrategy using `tpu_address`.
    num_gpus: Number of GPUs to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are "nccl" and
      "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
      "ring" and "nccl". If None, DistributionStrategy will choose based on
      device topology.
    num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
    tpu_address: Optional. String that represents TPU to connect to. Must not
      be None if `distribution_strategy` is set to `tpu`.
    **kwargs: Additional kwargs for internal usages.

  Returns:
    tf.distribute.DistributionStrategy object.

  Raises:
    ValueError: if `distribution_strategy` is "off" or "one_device" and
      `num_gpus` is larger than 1; or if `num_gpus` is negative; or if
      `distribution_strategy` is `tpu` but `tpu_address` is not specified.
  """
  del kwargs
  if num_gpus < 0:
    raise ValueError("`num_gpus` cannot be negative.")

  distribution_strategy = distribution_strategy.lower()
  if distribution_strategy == "off":
    if num_gpus > 1:
      raise ValueError("When {} GPUs are specified, distribution_strategy "
                       "flag cannot be set to `off`.".format(num_gpus))
    return None

  if distribution_strategy == "tpu":
    # When tpu_address is an empty string, we communicate with local TPUs.
    cluster_resolver = tpu_initialize(tpu_address)
    return tf.distribute.experimental.TPUStrategy(cluster_resolver)

  if distribution_strategy == "multi_worker_mirrored":
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if distribution_strategy == "one_device":
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    if num_gpus > 1:
      raise ValueError("`OneDeviceStrategy` cannot be used for more than "
                       "one device.")
    return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if distribution_strategy == "mirrored":
    if num_gpus == 0:
      devices = ["device:CPU:0"]
    else:
      devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  if distribution_strategy == "parameter_server":
    return tf.distribute.experimental.ParameterServerStrategy()

  raise ValueError("Unrecognized Distribution Strategy: %r" %
                   distribution_strategy)


def configure_cluster(worker_hosts=None, task_index=-1):
  """Set multi-worker cluster spec in TF_CONFIG environment variable.

  Args:
    worker_hosts: comma-separated list of worker ip:port pairs.
    task_index: index of this worker in the `worker_hosts` list; required when
      more than one worker is specified.

  Returns:
    Number of workers in the cluster.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if tf_config:
    num_workers = (
        len(tf_config["cluster"].get("chief", [])) +
        len(tf_config["cluster"].get("worker", [])))
  elif worker_hosts:
    workers = worker_hosts.split(",")
    num_workers = len(workers)
    if num_workers > 1 and task_index < 0:
      raise ValueError("Must specify task_index when number of workers > 1")
    task_index = 0 if num_workers == 1 else task_index
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {
            "worker": workers
        },
        "task": {
            "type": "worker",
            "index": task_index
        }
    })
  else:
    num_workers = 1
  return num_workers


def get_strategy_scope(strategy):
  if strategy:
    strategy_scope = strategy.scope()
  else:
    strategy_scope = DummyContextManager()
  return strategy_scope


class DummyContextManager(object):

  def __enter__(self):
    pass

  def __exit__(self, *args):
    pass
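
Note: a minimal usage sketch of the helpers above, assuming the module lands at `official.common.distribute_utils` (the import path used by the test below); flag values are illustrative only.

```python
import tensorflow as tf

from official.common import distribute_utils

# Mirrored strategy over two local GPUs with NCCL all-reduce; with
# num_gpus=0 the "mirrored" branch above falls back to CPU.
strategy = distribute_utils.get_distribution_strategy(
    distribution_strategy="mirrored",
    num_gpus=2,
    all_reduce_alg="nccl",
    num_packs=1)

# get_strategy_scope() returns a DummyContextManager when strategy is None
# (distribution_strategy="off"), so call sites stay uniform.
with distribute_utils.get_strategy_scope(strategy):
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])

# configure_cluster() writes TF_CONFIG for multi_worker_mirrored, e.g.
#   configure_cluster("10.0.0.1:2222,10.0.0.2:2222", task_index=0)
# sets TF_CONFIG to {"cluster": {"worker": [...]},
#                    "task": {"type": "worker", "index": 0}}
# and returns 2.
```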
...
@@ -14,32 +14,28 @@
 # ==============================================================================
 """ Tests for distribution util functions."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow.compat.v2 as tf
+import tensorflow as tf

-from official.utils.misc import distribution_utils
+from official.common import distribute_utils


 class GetDistributionStrategyTest(tf.test.TestCase):
   """Tests for get_distribution_strategy."""

   def test_one_device_strategy_cpu(self):
-    ds = distribution_utils.get_distribution_strategy(num_gpus=0)
+    ds = distribute_utils.get_distribution_strategy(num_gpus=0)
     self.assertEquals(ds.num_replicas_in_sync, 1)
     self.assertEquals(len(ds.extended.worker_devices), 1)
     self.assertIn('CPU', ds.extended.worker_devices[0])

   def test_one_device_strategy_gpu(self):
-    ds = distribution_utils.get_distribution_strategy(num_gpus=1)
+    ds = distribute_utils.get_distribution_strategy(num_gpus=1)
     self.assertEquals(ds.num_replicas_in_sync, 1)
     self.assertEquals(len(ds.extended.worker_devices), 1)
     self.assertIn('GPU', ds.extended.worker_devices[0])

   def test_mirrored_strategy(self):
-    ds = distribution_utils.get_distribution_strategy(num_gpus=5)
+    ds = distribute_utils.get_distribution_strategy(num_gpus=5)
     self.assertEquals(ds.num_replicas_in_sync, 5)
     self.assertEquals(len(ds.extended.worker_devices), 5)
     for device in ds.extended.worker_devices:
...
...
@@ -16,7 +16,7 @@
 """A common dataset reader."""

 import random
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Optional

 import tensorflow as tf
 import tensorflow_datasets as tfds
...
@@ -33,7 +33,6 @@ class InputReader:
   def __init__(self,
                params: cfg.DataConfig,
-               shards: Optional[List[str]] = None,
                dataset_fn=tf.data.TFRecordDataset,
                decoder_fn: Optional[Callable[..., Any]] = None,
                parser_fn: Optional[Callable[..., Any]] = None,
...
@@ -45,8 +44,6 @@ class InputReader:
     Args:
       params: A config_definitions.DataConfig object.
-      shards: A list of files to be read. If given, read from these files.
-        Otherwise, read from params.input_path.
       dataset_fn: A `tf.data.Dataset` that consumes the input files. For
         example, it can be `tf.data.TFRecordDataset`.
       decoder_fn: An optional `callable` that takes the serialized data string
...
@@ -56,36 +53,54 @@ class InputReader:
         model. It will be executed after decoder_fn.
       transform_and_batch_fn: An optional `callable` that takes a
         `tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
-        input, and returns a `tf.data.Dataset` object. It will be
-        executed after `parser_fn` to transform and batch the dataset; if None,
-        after `parser_fn` is executed, the dataset will be batched into
-        per-replica batch size.
+        input, and returns a `tf.data.Dataset` object. It will be executed after
+        `parser_fn` to transform and batch the dataset; if None, after
+        `parser_fn` is executed, the dataset will be batched into per-replica
+        batch size.
       postprocess_fn: An optional `callable` that processes batched tensors. It
         will be executed after batching.
     """
     if params.input_path and params.tfds_name:
       raise ValueError('At most one of `input_path` and `tfds_name` can be '
-                       'specified, but got %s and %s.' % (
-                           params.input_path, params.tfds_name))
+                       'specified, but got %s and %s.' %
+                       (params.input_path, params.tfds_name))
-    self._shards = shards
     self._tfds_builder = None
-    if self._shards:
-      self._num_files = len(self._shards)
-    elif not params.tfds_name:
-      self._input_patterns = params.input_path.strip().split(',')
-      self._num_files = 0
-      for input_pattern in self._input_patterns:
-        input_pattern = input_pattern.strip()
-        if not input_pattern:
-          continue
-        matched_files = tf.io.gfile.glob(input_pattern)
-        if not matched_files:
-          raise ValueError('%s does not match any files.' % input_pattern)
-        else:
-          self._num_files += len(matched_files)
-      if self._num_files == 0:
+    self._matched_files = []
+    if params.input_path:
+      # Read dataset from files.
+      usage = ('`input_path` should be either (1) a str indicating a file '
+               'path/pattern, or (2) a str indicating multiple file '
+               'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
+               '"a,b,c"), or (3) a list of str, each of which is a file '
+               'path/pattern or multiple file paths/patterns separated by '
+               'comma, but got: %s')
+      if isinstance(params.input_path, str):
+        input_path_list = [params.input_path]
+      elif isinstance(params.input_path, (list, tuple)):
+        if any(not isinstance(x, str) for x in params.input_path):
+          raise ValueError(usage % params.input_path)
+        input_path_list = params.input_path
+      else:
+        raise ValueError(usage % params.input_path)
+      for input_path in input_path_list:
+        input_patterns = input_path.strip().split(',')
+        for input_pattern in input_patterns:
+          input_pattern = input_pattern.strip()
+          if not input_pattern:
+            continue
+          if '*' in input_pattern or '?' in input_pattern:
+            tmp_matched_files = tf.io.gfile.glob(input_pattern)
+            if not tmp_matched_files:
+              raise ValueError('%s does not match any files.' % input_pattern)
+            self._matched_files.extend(tmp_matched_files)
+          else:
+            self._matched_files.append(input_pattern)
+      if not self._matched_files:
         raise ValueError('%s does not match any files.' % params.input_path)
     else:
+      # Read dataset from TFDS.
       if not params.tfds_split:
         raise ValueError(
             '`tfds_name` is %s, but `tfds_split` is not specified.' %
...
@@ -102,7 +117,6 @@ class InputReader:
     self._block_length = params.block_length
     self._deterministic = params.deterministic
     self._sharding = params.sharding
-    self._examples_consume = params.examples_consume
     self._tfds_split = params.tfds_split
     self._tfds_download = params.tfds_download
     self._tfds_as_supervised = params.tfds_as_supervised
...
@@ -120,23 +134,16 @@
     self._tf_data_service_address = params.tf_data_service_address
     self._tf_data_service_job_name = params.tf_data_service_job_name

-  def _read_sharded_files(
-      self,
-      input_context: Optional[tf.distribute.InputContext] = None):
+  def _read_sharded_files(self,
+                          input_context: Optional[
+                              tf.distribute.InputContext] = None):
     """Reads a dataset from sharded files."""
-    # Read from `self._shards` if it is provided.
-    if self._shards:
-      dataset = tf.data.Dataset.from_tensor_slices(self._shards)
-    else:
-      dataset = tf.data.Dataset.list_files(
-          self._input_patterns,
-          seed=self._seed,
-          shuffle=self._is_training)
+    dataset = tf.data.Dataset.from_tensor_slices(self._matched_files)

     # Shuffle and repeat at file level.
-    if self._shards and self._is_training:
+    if self._is_training:
       dataset = dataset.shuffle(
-          len(self._shards),
+          len(self._matched_files),
           seed=self._seed,
           reshuffle_each_iteration=True)
...
@@ -158,12 +165,12 @@
           deterministic=self._deterministic)
     return dataset

-  def _read_single_file(
-      self,
-      input_context: Optional[tf.distribute.InputContext] = None):
+  def _read_single_file(self,
+                        input_context: Optional[
+                            tf.distribute.InputContext] = None):
     """Reads a dataset from a single file."""
     # Read from `self._shards` if it is provided.
-    dataset = self._dataset_fn(self._shards or self._input_patterns)
+    dataset = self._dataset_fn(self._matched_files)

     # When `input_file` is a path to a single file, disable auto sharding
     # so that same input file is sent to all workers.
...
@@ -225,11 +232,13 @@
     """Generates a tf.data.Dataset object."""
     if self._tfds_builder:
       dataset = self._read_tfds(input_context)
-    elif self._num_files > 1:
+    elif len(self._matched_files) > 1:
       dataset = self._read_sharded_files(input_context)
-    else:
-      assert self._num_files == 1
+    elif len(self._matched_files) == 1:
       dataset = self._read_single_file(input_context)
+    else:
+      raise ValueError('It is unexpected that `tfds_builder` is None and '
+                       'there is also no `matched_files`.')

     if self._cache:
       dataset = dataset.cache()
...
@@ -237,9 +246,6 @@
     if self._is_training:
       dataset = dataset.shuffle(self._shuffle_buffer_size)

-    if self._examples_consume > 0:
-      dataset = dataset.take(self._examples_consume)
-
     def maybe_map_fn(dataset, fn):
       return dataset if fn is None else dataset.map(
           fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
...
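
Note: the rewritten `input_path` handling accepts three spellings. A hedged sketch of the equivalence, using dict-style construction as `OptimizationConfig` does in the tests further down; the import path and file names are assumptions:

```python
from official.core import config_definitions as cfg  # module path assumed

# All three configs resolve to the same matched-file list in
# InputReader.__init__ above:
d1 = cfg.DataConfig({"input_path": "/data/train-*.tfrecord"})      # (1) pattern
d2 = cfg.DataConfig({"input_path": "/data/a.tfrecord,/data/b.tfrecord"})  # (2)
d3 = cfg.DataConfig({"input_path": ["/data/a.tfrecord",            # (3) list
                                    "/data/b.tfrecord"]})
```

Only entries containing `*` or `?` are globbed; literal paths are appended verbatim, so a missing non-glob file surfaces at read time rather than in the constructor.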
...
@@ -14,7 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 """TFM common training driver library."""
-
+# pytype: disable=attribute-error
 import copy
 import json
 import os
...
@@ -219,9 +219,14 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
   elif mode == 'eval':
     controller.evaluate(steps=params.trainer.validation_steps)
   elif mode == 'continuous_eval':
+
+    def timeout_fn():
+      if trainer.global_step.numpy() >= params.trainer.train_steps:
+        return True
+      return False
+
     controller.evaluate_continuously(
         steps=params.trainer.validation_steps,
-        timeout=params.trainer.continuous_eval_timeout)
+        timeout=params.trainer.continuous_eval_timeout,
+        timeout_fn=timeout_fn)
   else:
     raise NotImplementedError('The mode is not implemented: %s' % mode)
...
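
Note: the added `timeout_fn` gives continuous eval an early exit once training is done, instead of always waiting out `continuous_eval_timeout`. A fragment restating the contract (names come from `run_experiment`'s scope; Orbit polls the callable between checkpoint checks):

```python
def timeout_fn() -> bool:
  # True once the final checkpoint has been produced, ending the
  # continuous-eval loop without waiting for the full timeout to elapse.
  return trainer.global_step.numpy() >= params.trainer.train_steps
```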
...
@@ -49,6 +49,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
         'train_steps': 10,
         'validation_steps': 5,
         'validation_interval': 10,
+        'continuous_eval_timeout': 1,
         'optimizer_config': {
             'optimizer': {
                 'type': 'sgd',
...
@@ -97,9 +98,19 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
       self.assertEmpty(logs)
     self.assertNotEmpty(
         tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
-    if flag_mode != 'eval':
-      self.assertNotEmpty(
-          tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
+    if flag_mode == 'eval':
+      return
+    self.assertNotEmpty(
+        tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
+
+    # Tests continuous evaluation.
+    _, logs = train_lib.run_experiment(
+        distribution_strategy=distribution_strategy,
+        task=task,
+        mode='continuous_eval',
+        params=params,
+        model_dir=model_dir,
+        run_post_eval=run_post_eval)
+    print(logs)


 if __name__ == '__main__':
...
...
@@ -18,9 +18,10 @@
 import json
 import os
 import pprint
-from typing import Any
+from typing import Any, List

 from absl import logging
+import dataclasses
 import orbit
 import tensorflow as tf
...
@@ -37,7 +38,7 @@ def create_trainer(
     model_dir: str,
     train: bool,
     evaluate: bool,
-    checkpoint_exporter: Any = None):
+    checkpoint_exporter: Any = None) -> base_trainer.Trainer:
   """Create trainer."""
   del model_dir
   logging.info('Running default trainer.')
...
@@ -47,6 +48,16 @@ def create_trainer(
   return trainer


+@dataclasses.dataclass
+class ParseConfigOptions:
+  """Use this dataclass instead of FLAGS to customize parse_configuration()."""
+  experiment: str
+  config_file: List[str]
+  tpu: str = ''
+  tf_data_service: str = ''
+  params_override: str = ''
+
+
 def parse_configuration(flags_obj):
   """Parses ExperimentConfig from flags."""
...
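
Note: a sketch of how the new `ParseConfigOptions` is meant to be used — build a flag-like object programmatically and hand it to `parse_configuration` instead of absl `FLAGS`. The experiment name and file paths are hypothetical:

```python
options = ParseConfigOptions(
    experiment="bert/sentence_prediction",
    config_file=["configs/my_experiment.yaml"],
    params_override="trainer.train_steps=1000")
params = parse_configuration(options)
```

The dataclass only needs to expose the attributes `parse_configuration` reads (`experiment`, `config_file`, `tpu`, `tf_data_service`, `params_override`), which is why it mirrors exactly that subset of the flags.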
...
@@ -15,7 +15,7 @@
 # ==============================================================================
 """Common configuration settings."""

-from typing import Optional, Union
+from typing import Optional, Sequence, Union

 import dataclasses
...
@@ -30,9 +30,12 @@ class DataConfig(base_config.Config):
   """The base configuration for building datasets.

   Attributes:
-    input_path: The path to the input. It can be either (1) a file pattern, or
-      (2) multiple file patterns separated by comma. It should not be specified
-      when the following `tfds_name` is specified.
+    input_path: The path to the input. It can be either (1) a str indicating
+      a file path/pattern, or (2) a str indicating multiple file paths/patterns
+      separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or (3) a list of
+      str, each of which is a file path/pattern or multiple file paths/patterns
+      separated by comma. It should not be specified when the following
+      `tfds_name` is specified.
     tfds_name: The name of the tensorflow dataset (TFDS). It should not be
       specified when the above `input_path` is specified.
     tfds_split: A str indicating which split of the data to load from TFDS. It
...
@@ -50,10 +53,6 @@ class DataConfig(base_config.Config):
       element before cycling to another input element when interleaving files.
     deterministic: A boolean controlling whether determinism should be enforced.
     sharding: Whether sharding is used in the input pipeline.
-    examples_consume: An `integer` specifying the number of examples it will
-      produce. If positive, it only takes this number of examples and raises
-      tf.error.OutOfRangeError after that. Default is -1, meaning it will
-      exhaust all the examples in the dataset.
     enable_tf_data_service: A boolean indicating whether to enable tf.data
       service for the input pipeline.
     tf_data_service_address: The URI of a tf.data service to offload
...
@@ -75,7 +74,7 @@ class DataConfig(base_config.Config):
       features. The main use case is to skip the image/video decoding for better
       performance.
   """
-  input_path: str = ""
+  input_path: Union[Sequence[str], str] = ""
   tfds_name: str = ""
   tfds_split: str = ""
   global_batch_size: int = 0
...
@@ -87,7 +86,6 @@ class DataConfig(base_config.Config):
   block_length: int = 1
   deterministic: Optional[bool] = None
   sharding: bool = True
-  examples_consume: int = -1
   enable_tf_data_service: bool = False
   tf_data_service_address: Optional[str] = None
   tf_data_service_job_name: Optional[str] = None
...
@@ -126,8 +124,6 @@ class RuntimeConfig(base_config.Config):
     run_eagerly: Whether or not to run the experiment eagerly.
     batchnorm_spatial_persistent: Whether or not to enable the spatial
       persistent mode for CuDNN batch norm kernel for improved GPU performance.
-    allow_tpu_summary: Whether to allow summary happen inside the XLA program
-      runs on TPU through automatic outside compilation.
   """
   distribution_strategy: str = "mirrored"
   enable_xla: bool = False
...
@@ -145,6 +141,15 @@ class RuntimeConfig(base_config.Config):
   run_eagerly: bool = False
   batchnorm_spatial_persistent: bool = False

+  # Global model parallelism configurations.
+  num_cores_per_replica: int = 1
+  default_shard_dim: int = -1
+
+  def model_parallelism(self):
+    return dict(
+        num_cores_per_replica=self.num_cores_per_replica,
+        default_shard_dim=self.default_shard_dim)
+

 @dataclasses.dataclass
 class TensorboardConfig(base_config.Config):
...
@@ -167,12 +172,15 @@ class CallbacksConfig(base_config.Config):
   Attributes:
     enable_checkpoint_and_export: Whether or not to enable checkpoints as a
       Callback. Defaults to True.
+    enable_backup_and_restore: Whether or not to add the BackupAndRestore
+      callback. Defaults to False.
     enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
       Defaults to True.
     enable_time_history: Whether or not to enable TimeHistory Callbacks.
       Defaults to True.
   """
   enable_checkpoint_and_export: bool = True
+  enable_backup_and_restore: bool = False
   enable_tensorboard: bool = True
   enable_time_history: bool = True
...
@@ -187,6 +195,8 @@ class TrainerConfig(base_config.Config):
     train_tf_while_loop: whether or not to use tf while loop.
     train_tf_function: whether or not to use tf_function for training loop.
     eval_tf_function: whether or not to use tf_function for eval.
+    allow_tpu_summary: Whether to allow summaries to run inside the XLA program
+      on TPU through automatic outside compilation.
     steps_per_loop: number of steps per loop.
     summary_interval: number of steps between each summary.
     checkpoint_interval: number of steps between checkpoints.
...
@@ -194,7 +204,7 @@ class TrainerConfig(base_config.Config):
     continuous_eval_timeout: maximum number of seconds to wait between
       checkpoints, if set to None, continuous eval will wait indefinitely. This
       is only used in continuous_train_and_eval and continuous_eval modes.
-      Default value is 24 hrs.
+      Default value is 1 hr.
     train_steps: number of train steps.
     validation_steps: number of eval steps. If `None`, the entire eval dataset
       is used.
...
@@ -223,7 +233,7 @@ class TrainerConfig(base_config.Config):
   checkpoint_interval: int = 1000
   # Checkpoint manager.
   max_to_keep: int = 5
-  continuous_eval_timeout: int = 24 * 60 * 60
+  continuous_eval_timeout: int = 60 * 60
   # Train/Eval routines.
   train_steps: int = 0
   validation_steps: Optional[int] = None
...
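
Note: two behavioral changes here are easy to miss in the noise. A quick sketch of the new defaults, assuming the config classes above can be instantiated with no arguments (their fields all have defaults):

```python
runtime = RuntimeConfig()
# New model-parallelism knobs, packaged for downstream consumers:
assert runtime.model_parallelism() == {
    "num_cores_per_replica": 1,
    "default_shard_dim": -1,
}

trainer = TrainerConfig()
# The continuous-eval timeout default drops from 24 hours to 1 hour:
assert trainer.continuous_eval_timeout == 60 * 60
```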
...
@@ -26,15 +26,15 @@ class OptimizerConfigTest(tf.test.TestCase):

   def test_no_optimizer(self):
     optimizer = optimization_config.OptimizationConfig({}).optimizer.get()
-    self.assertEqual(optimizer, None)
+    self.assertIsNone(optimizer)

   def test_no_lr_schedule(self):
     lr = optimization_config.OptimizationConfig({}).learning_rate.get()
-    self.assertEqual(lr, None)
+    self.assertIsNone(lr)

   def test_no_warmup_schedule(self):
     warmup = optimization_config.OptimizationConfig({}).warmup.get()
-    self.assertEqual(warmup, None)
+    self.assertIsNone(warmup)

   def test_config(self):
     opt_config = optimization_config.OptimizationConfig({
...
...
@@ -21,7 +21,21 @@ from official.modeling.hyperparams import base_config

 @dataclasses.dataclass
-class SGDConfig(base_config.Config):
+class BaseOptimizerConfig(base_config.Config):
+  """Base optimizer config.
+
+  Attributes:
+    clipnorm: float >= 0 or None. If not None, gradients will be clipped when
+      their L2 norm exceeds this value.
+    clipvalue: float >= 0 or None. If not None, gradients will be clipped when
+      their absolute value exceeds this value.
+  """
+  clipnorm: Optional[float] = None
+  clipvalue: Optional[float] = None
+
+
+@dataclasses.dataclass
+class SGDConfig(BaseOptimizerConfig):
   """Configuration for SGD optimizer.

   The attributes for this class match the arguments of tf.keras.optimizer.SGD.
...
@@ -39,7 +53,7 @@ class SGDConfig(base_config.Config):

 @dataclasses.dataclass
-class RMSPropConfig(base_config.Config):
+class RMSPropConfig(BaseOptimizerConfig):
   """Configuration for RMSProp optimizer.

   The attributes for this class match the arguments of
...
@@ -60,7 +74,7 @@ class RMSPropConfig(base_config.Config):

 @dataclasses.dataclass
-class AdamConfig(base_config.Config):
+class AdamConfig(BaseOptimizerConfig):
   """Configuration for Adam optimizer.

   The attributes for this class match the arguments of
...
@@ -82,7 +96,7 @@ class AdamConfig(base_config.Config):

 @dataclasses.dataclass
-class AdamWeightDecayConfig(base_config.Config):
+class AdamWeightDecayConfig(BaseOptimizerConfig):
   """Configuration for Adam optimizer with weight decay.

   Attributes:
...
@@ -110,7 +124,7 @@ class AdamWeightDecayConfig(base_config.Config):

 @dataclasses.dataclass
-class LAMBConfig(base_config.Config):
+class LAMBConfig(BaseOptimizerConfig):
   """Configuration for LAMB optimizer.

   The attributes for this class match the arguments of
...
@@ -139,7 +153,7 @@ class LAMBConfig(base_config.Config):

 @dataclasses.dataclass
-class EMAConfig(base_config.Config):
+class EMAConfig(BaseOptimizerConfig):
   """Exponential moving average optimizer config.

   Attributes:
...
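
Note: since every optimizer config now derives from `BaseOptimizerConfig`, gradient clipping is configurable uniformly across optimizers. A minimal sketch, using the dict-style construction these configs support (as `OptimizationConfig({...})` does in the tests):

```python
sgd = SGDConfig({"clipnorm": 1.0})     # clip gradients to unit L2 norm
adam = AdamConfig({"clipvalue": 0.5})  # clamp each gradient element to [-0.5, 0.5]
assert sgd.clipvalue is None           # unset clipping fields stay None
```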
...
@@ -144,6 +144,12 @@ class OptimizerFactory(object):
     """
     optimizer_dict = self._optimizer_config.as_dict()
+    ## Delete clipnorm and clipvalue if None
+    if optimizer_dict['clipnorm'] is None:
+      del optimizer_dict['clipnorm']
+    if optimizer_dict['clipvalue'] is None:
+      del optimizer_dict['clipvalue']
+
     optimizer_dict['learning_rate'] = lr

     optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
...
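
Note: the None-stripping above matters because the factory splats `optimizer_dict` straight into the Keras optimizer constructor. A hedged sketch of the pattern (how Keras rejects an explicit `clipnorm=None` varies by version, so the exact error is not shown):

```python
import tensorflow as tf

kwargs = {"learning_rate": 0.1, "momentum": 0.9,
          "clipnorm": None, "clipvalue": None}

# Drop unset clipping entries before splatting, mirroring the factory code:
for key in ("clipnorm", "clipvalue"):
  if kwargs[key] is None:
    del kwargs[key]

optimizer = tf.keras.optimizers.SGD(**kwargs)
```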
...
@@ -14,9 +14,8 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for optimizer_factory.py."""
-
-
 from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf

 from official.modeling.optimization import optimizer_factory
...
@@ -50,6 +49,49 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     self.assertIsInstance(optimizer, optimizer_cls)
     self.assertEqual(expected_optimizer_config, optimizer.get_config())

+  @parameterized.parameters(
+      (None, None),
+      (1.0, None),
+      (None, 1.0))
+  def test_gradient_clipping(self, clipnorm, clipvalue):
+    params = {
+        'optimizer': {
+            'type': 'sgd',
+            'sgd': {
+                'clipnorm': clipnorm,
+                'clipvalue': clipvalue
+            }
+        },
+        'learning_rate': {
+            'type': 'constant',
+            'constant': {
+                'learning_rate': 1.0
+            }
+        }
+    }
+    opt_config = optimization_config.OptimizationConfig(params)
+    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+    lr = opt_factory.build_learning_rate()
+    optimizer = opt_factory.build_optimizer(lr)
+
+    var0 = tf.Variable([1.0, 2.0])
+    var1 = tf.Variable([3.0, 4.0])
+
+    grads0 = tf.constant([0.1, 0.1])
+    grads1 = tf.constant([2.0, 3.0])
+
+    grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
+    optimizer.apply_gradients(grads_and_vars)
+
+    self.assertAllClose(np.array([0.9, 1.9]), var0.numpy())
+    if clipvalue is not None:
+      self.assertAllClose(np.array([2.0, 3.0]), var1.numpy())
+    elif clipnorm is not None:
+      self.assertAllClose(np.array([2.4452999, 3.1679497]), var1.numpy())
+    else:
+      self.assertAllClose(np.array([1.0, 1.0]), var1.numpy())
+
   def test_missing_types(self):
     params = {'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}}}
     with self.assertRaises(ValueError):
...
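
Note: the clipping expectations in `test_gradient_clipping` check out by hand. With `clipnorm=1.0` and `lr=1.0`, the gradient `[2.0, 3.0]` (L2 norm √13 ≈ 3.6056) is rescaled to unit norm before the SGD step; with `clipvalue=1.0` it is clamped elementwise to `[1.0, 1.0]`, giving the `[2.0, 3.0]` result asserted above.

```python
import numpy as np

grad = np.array([2.0, 3.0])
clipped = grad / np.linalg.norm(grad)        # [0.5547, 0.8321], norm 1.0
print(np.array([3.0, 4.0]) - 1.0 * clipped)  # [2.4453, 3.1679], as asserted
```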
...
@@ -31,7 +31,7 @@ import tensorflow as tf
 from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any

 from official.modeling.hyperparams import params_dict
 from official.utils import hyperparams_flags
-from official.utils.misc import distribution_utils
+from official.common import distribute_utils
 from official.utils.misc import keras_utils

 FLAGS = flags.FLAGS
...
@@ -745,8 +745,8 @@ class ExecutorBuilder(object):
   """

   def __init__(self, strategy_type=None, strategy_config=None):
-    _ = distribution_utils.configure_cluster(strategy_config.worker_hosts,
-                                             strategy_config.task_index)
+    _ = distribute_utils.configure_cluster(strategy_config.worker_hosts,
+                                           strategy_config.task_index)
     """Constructor.

     Args:
...
@@ -756,7 +756,7 @@ class ExecutorBuilder(object):
       strategy_config: necessary config for constructing the proper Strategy.
         Check strategy_flags_dict() for examples of the structure.
     """
-    self._strategy = distribution_utils.get_distribution_strategy(
+    self._strategy = distribute_utils.get_distribution_strategy(
         distribution_strategy=strategy_type,
         num_gpus=strategy_config.num_gpus,
         all_reduce_alg=strategy_config.all_reduce_alg,
...
...
@@ -40,8 +40,7 @@ class AlbertConfig(configs.BertConfig):
     super(AlbertConfig, self).__init__(**kwargs)

     # TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
-    # in the released ALBERT. Support other values in AlbertTransformerEncoder
-    # if needed.
+    # in the released ALBERT. Support other values in AlbertEncoder if needed.
     if inner_group_num != 1 or num_hidden_groups != 1:
       raise ValueError("We only support 'inner_group_num' and "
                        "'num_hidden_groups' as 1.")
...
...
@@ -26,11 +26,10 @@ from absl import app
 from absl import flags
 from absl import logging
 import tensorflow as tf
+from official.common import distribute_utils
 from official.nlp.albert import configs as albert_configs
 from official.nlp.bert import bert_models
 from official.nlp.bert import run_classifier as run_classifier_bert
-from official.utils.misc import distribution_utils

 FLAGS = flags.FLAGS
...
@@ -77,7 +76,7 @@ def main(_):
   if not FLAGS.model_dir:
     FLAGS.model_dir = '/tmp/bert20/'

-  strategy = distribution_utils.get_distribution_strategy(
+  strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=FLAGS.distribution_strategy,
       num_gpus=FLAGS.num_gpus,
       tpu_address=FLAGS.tpu)
...
...
@@ -27,12 +27,11 @@ from absl import app
 from absl import flags
 from absl import logging
 import tensorflow as tf
+from official.common import distribute_utils
 from official.nlp.albert import configs as albert_configs
 from official.nlp.bert import run_squad_helper
 from official.nlp.bert import tokenization
 from official.nlp.data import squad_lib_sp
-from official.utils.misc import distribution_utils

 flags.DEFINE_string(
     'sp_model_file', None,
...
@@ -104,9 +103,8 @@ def main(_):
   # Configures cluster spec for multi-worker distribution strategy.
   if FLAGS.num_gpus > 0:
-    _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
-                                             FLAGS.task_index)
-  strategy = distribution_utils.get_distribution_strategy(
+    _ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
+  strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=FLAGS.distribution_strategy,
       num_gpus=FLAGS.num_gpus,
       all_reduce_alg=FLAGS.all_reduce_alg,
...
...
@@ -15,7 +15,7 @@
 """A converter from a tf1 ALBERT encoder checkpoint to a tf2 encoder checkpoint.

 The conversion will yield an object-oriented checkpoint that can be used
-to restore a AlbertTransformerEncoder object.
+to restore an AlbertEncoder object.
 """

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
...
@@ -81,7 +81,7 @@ def _create_albert_model(cfg):
   Returns:
     A keras model.
   """
-  albert_encoder = networks.AlbertTransformerEncoder(
+  albert_encoder = networks.AlbertEncoder(
       vocab_size=cfg.vocab_size,
       hidden_size=cfg.hidden_size,
       embedding_width=cfg.embedding_size,
...
...
@@ -167,7 +167,7 @@ def get_transformer_encoder(bert_config,
       initializer=tf.keras.initializers.TruncatedNormal(
           stddev=bert_config.initializer_range))
   if isinstance(bert_config, albert_configs.AlbertConfig):
-    return networks.AlbertTransformerEncoder(**kwargs)
+    return networks.AlbertEncoder(**kwargs)
   else:
     assert isinstance(bert_config, configs.BertConfig)
     kwargs['output_range'] = output_range
...
...
@@ -285,5 +285,22 @@ def create_retrieval_dataset(file_path,
       _select_data_from_record,
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
   dataset = dataset.batch(batch_size, drop_remainder=False)
+
+  def _pad_to_batch(x, y):
+    cur_size = tf.shape(y)[0]
+    pad_size = batch_size - cur_size
+    pad_ids = tf.zeros(shape=[pad_size, seq_length], dtype=tf.int32)
+    for key in ('input_word_ids', 'input_mask', 'input_type_ids'):
+      x[key] = tf.concat([x[key], pad_ids], axis=0)
+    pad_labels = -tf.ones(shape=[pad_size, 1], dtype=tf.int32)
+    y = tf.concat([y, pad_labels], axis=0)
+    return x, y
+
+  dataset = dataset.map(
+      _pad_to_batch,
+      num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
   dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
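
Note: with `drop_remainder=False` the last batch can be short, so `_pad_to_batch` pads it back to `batch_size` with zero-filled features and a label of `-1` that downstream metrics can mask out. A standalone sketch of the shape behavior with toy sizes:

```python
import tensorflow as tf

batch_size, seq_length = 4, 8
# Final batch holding only 3 real examples:
x = {k: tf.ones([3, seq_length], tf.int32)
     for k in ("input_word_ids", "input_mask", "input_type_ids")}
y = tf.ones([3, 1], tf.int32)

pad_size = batch_size - int(y.shape[0])              # 1 row of padding
pad_ids = tf.zeros([pad_size, seq_length], tf.int32)
x = {k: tf.concat([v, pad_ids], axis=0) for k, v in x.items()}
y = tf.concat([y, -tf.ones([pad_size, 1], tf.int32)], axis=0)

print(y.shape)  # (4, 1); rows with label -1 are padding
```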