Unverified commit b9c1d1ca, authored by Igor, committed via GitHub.
Browse files

Add distribute strategies to transformer. (#6883)

* Fixes that make transformer run.

* Remove debug print statements.

* Changed the permissions to 644.

* Fix the rest of the permissions.

* enable static batch in all benchmarks

* Restrict dist strat hack to training mode

For now we will do predict/eval without dist strat, so remove that hack in non training cases.

* Use `inputs` instead of `x` as arg name for call

Keras has different behavior based on whether the inputs are called `inputs` or not. Using `inputs` gives expected behaviors.

* Avoid extra map fn on input in dist strat case

* Update how we handle custom metrics

This new approach works with and without dist strat. The previous one didn't work with dist strat. We need to fix that but this is reasonable in meantime (b/133724664).

* Update benchmarks

* typo in metrics code

* Revert metrics change

Didn't actually work in the distributed case.
parent 7af3bd91
......@@ -160,9 +160,12 @@ class MetricLayer(tf.keras.layers.Layer):
def call(self, inputs):
  """Records accuracy-style metrics for (logits, targets) and forwards logits.

  The scraped diff interleaved the pre-change unconditional metric loop with
  the post-change guarded loop, so the metrics were added twice and the
  cross-replica guard was bypassed. This reconstructs the post-commit body:
  metrics are only created when a distribution strategy is active and we are
  inside a replica context.

  Args:
    inputs: two-element sequence; inputs[0] is the logits tensor,
      inputs[1] is the targets tensor.

  Returns:
    The logits tensor unchanged, so this layer is a pass-through in the model.
  """
  logits, targets = inputs[0], inputs[1]
  # TODO(guptapriya): Remove this check when underlying issue to create metrics
  # with dist strat in cross replica context is fixed.
  if tf.distribute.has_strategy() and not tf.distribute.in_cross_replica_context():
    for mean, fn in self.metric_mean_fns:
      # fn returns the (value, weight) pair the Mean metric expects.
      m = mean(*fn(logits, targets))
      self.add_metric(m)
  return logits
......
......@@ -85,11 +85,11 @@ class Transformer(tf.keras.Model):
"params": self.params,
}
def call(self, x, training):
def call(self, inputs, training):
"""Calculate target logits or inferred target sequences.
Args:
x: input tensor list of size 1 or 2.
inputs: input tensor list of size 1 or 2.
First item, inputs: int tensor with shape [batch_size, input_length].
Second item (optional), targets: None or int tensor with shape
[batch_size, target_length].
......@@ -103,10 +103,10 @@ class Transformer(tf.keras.Model):
outputs: [batch_size, decoded length]
scores: [batch_size, float]}
"""
if len(x) == 2:
inputs, targets = x[0], x[1]
if len(inputs) == 2:
inputs, targets = inputs[0], inputs[1]
else:
inputs, targets = x[0], None
inputs, targets = inputs[0], None
# Variance scaling is used here because it seems to work in many problems.
# Other reasonable initializers may also work just as well.
......
......@@ -140,7 +140,7 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 4096
FLAGS.train_steps = 200000
FLAGS.train_steps = 100000
FLAGS.steps_between_evals = 5000
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
# These bleu scores are based on test runs after at this limited
......@@ -150,6 +150,144 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
bleu_min=25.3,
bleu_max=26)
def benchmark_1_gpu_static_batch(self):
  """Benchmark 1 gpu with static_batch.

  The paper uses 8 GPUs and a much larger effective batch size; this will
  not converge to the 27.3 BLEU (uncased) SOTA.
  """
  self._setup()
  FLAGS.num_gpus = 1
  FLAGS.data_dir = self.train_data_dir
  FLAGS.vocab_file = self.vocab_file
  # Sets values directly to avoid validation check.
  FLAGS['bleu_source'].value = self.bleu_source
  FLAGS['bleu_ref'].value = self.bleu_ref
  FLAGS.param_set = 'base'
  FLAGS.batch_size = 4096
  FLAGS.train_steps = 100000
  FLAGS.steps_between_evals = 5000
  # TODO(guptapriya): Add max_length
  # NOTE(review): presumably pads batches to fixed shapes — confirm against
  # the data pipeline's handling of the static_batch flag.
  FLAGS.static_batch = True
  FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_static_batch')
  # These bleu bounds are based on test runs at this limited number of
  # steps and batch size, after verifying SOTA at 8xV100s.
  self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                 log_steps=FLAGS.log_steps,
                                 bleu_min=25.3,
                                 bleu_max=26)
def benchmark_8_gpu(self):
  """Benchmark 8 gpu.

  Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
  """
  self._setup()
  FLAGS.num_gpus = 8
  FLAGS.data_dir = self.train_data_dir
  FLAGS.vocab_file = self.vocab_file
  # Sets values directly to avoid validation check.
  FLAGS['bleu_source'].value = self.bleu_source
  FLAGS['bleu_ref'].value = self.bleu_ref
  FLAGS.param_set = 'base'
  # Global batch: 4096 tokens per GPU across 8 GPUs.
  FLAGS.batch_size = 4096*8
  FLAGS.train_steps = 100000
  FLAGS.steps_between_evals = 5000
  FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
  self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                 log_steps=FLAGS.log_steps,
                                 bleu_min=27,
                                 bleu_max=28)
def benchmark_8_gpu_static_batch(self):
  """Benchmark 8 gpu with static_batch.

  Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
  """
  self._setup()
  FLAGS.num_gpus = 8
  FLAGS.data_dir = self.train_data_dir
  FLAGS.vocab_file = self.vocab_file
  # Sets values directly to avoid validation check.
  FLAGS['bleu_source'].value = self.bleu_source
  FLAGS['bleu_ref'].value = self.bleu_ref
  FLAGS.param_set = 'base'
  # Global batch: 4096 tokens per GPU across 8 GPUs.
  FLAGS.batch_size = 4096*8
  FLAGS.train_steps = 100000
  # TODO(guptapriya): Add max_length
  FLAGS.static_batch = True
  FLAGS.steps_between_evals = 5000
  FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
  self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                 log_steps=FLAGS.log_steps,
                                 bleu_min=27,
                                 bleu_max=28)
class TransformerBigKerasAccuracy(TransformerBenchmark):
  """Benchmark accuracy tests for Transformer Big model w/ Keras."""

  def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
    """Benchmark accuracy tests for Transformer Big model w/ Keras.

    Args:
      output_dir: directory where to output e.g. log files
      root_data_dir: directory under which to look for dataset
      **kwargs: arbitrary named arguments. This is needed to make the
        constructor forward compatible in case PerfZero provides more
        named arguments before updating the constructor.
    """
    flag_methods = [misc.define_transformer_flags]
    super(TransformerBigKerasAccuracy, self).__init__(
        output_dir=output_dir, root_data_dir=root_data_dir,
        flag_methods=flag_methods)

  def benchmark_8_gpu(self):
    """Benchmark 8 gpu.

    Should converge to 28.4 BLEU (uncased). This has not been verified yet.
    """
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.data_dir = self.train_data_dir
    FLAGS.vocab_file = self.vocab_file
    # Sets values directly to avoid validation check.
    FLAGS['bleu_source'].value = self.bleu_source
    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'big'
    # Global batch: 3072 tokens per GPU across 8 GPUs (Big model needs a
    # smaller per-GPU batch than Base to fit in memory).
    FLAGS.batch_size = 3072*8
    FLAGS.train_steps = 100000
    FLAGS.steps_between_evals = 5000
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
    self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                   log_steps=FLAGS.log_steps,
                                   bleu_min=28,
                                   bleu_max=29)

  def benchmark_8_gpu_static_batch(self):
    """Benchmark 8 gpu with static_batch.

    Should converge to 28.4 BLEU (uncased). This has not been verified yet.
    """
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.data_dir = self.train_data_dir
    FLAGS.vocab_file = self.vocab_file
    # Sets values directly to avoid validation check.
    FLAGS['bleu_source'].value = self.bleu_source
    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'big'
    FLAGS.batch_size = 3072*8
    # TODO(guptapriya): Add max_length
    FLAGS.static_batch = True
    FLAGS.train_steps = 100000
    FLAGS.steps_between_evals = 5000
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
    self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                   log_steps=FLAGS.log_steps,
                                   bleu_min=28,
                                   bleu_max=29)
class TransformerKerasBenchmark(TransformerBenchmark):
"""Benchmarks for Transformer (Base and Big) using Keras."""
......@@ -182,6 +320,37 @@ class TransformerKerasBenchmark(TransformerBenchmark):
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_1_gpu_static_batch(self):
  """Benchmark 1 gpu with static_batch."""
  self._setup()
  FLAGS.num_gpus = 1
  # Per-GPU batch size supplied by the concrete subclass.
  FLAGS.batch_size = self.batch_per_gpu
  FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_static_batch')
  # TODO(guptapriya): Add max_length
  FLAGS.static_batch = True
  self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                 log_steps=FLAGS.log_steps)
def benchmark_8_gpu(self):
  """Benchmark 8 gpu."""
  self._setup()
  FLAGS.num_gpus = 8
  # Scale the global batch linearly with GPU count.
  FLAGS.batch_size = self.batch_per_gpu * 8
  FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
  self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                 log_steps=FLAGS.log_steps)
def benchmark_8_gpu_static_batch(self):
  """Benchmark 8 gpu with static_batch."""
  self._setup()
  # Fix(review): this was `FLAGS.num_gpus = 1`, contradicting the method name
  # and the 8x per-GPU batch size below — a copy/paste slip from the 1-GPU
  # variant that would have run the "8 gpu" benchmark on a single GPU.
  FLAGS.num_gpus = 8
  FLAGS.batch_size = self.batch_per_gpu * 8
  FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
  # TODO(guptapriya): Add max_length
  FLAGS.static_batch = True
  self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                 log_steps=FLAGS.log_steps)
class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests."""
......
......@@ -39,6 +39,7 @@ from official.transformer.v2 import transformer
from official.transformer.v2 import translate
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
INF = int(1e9)
......@@ -92,10 +93,15 @@ class TransformerTask(object):
# Add flag-defined parameters to params object
num_gpus = flags_core.get_num_gpus(flags_obj)
self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_core.get_num_gpus(flags_obj))
self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)
params["data_dir"] = flags_obj.data_dir
params["model_dir"] = flags_obj.model_dir
params["static_batch"] = flags_obj.static_batch
params["num_parallel_calls"] = (
flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
......@@ -107,16 +113,25 @@ class TransformerTask(object):
"""Trains the model."""
params, flags_obj, is_train = self.params, self.flags_obj, True
_ensure_dir(flags_obj.model_dir)
model = transformer.create_model(params, is_train)
opt = self._create_optimizer()
if self.distribution_strategy:
with self.distribution_strategy.scope():
model = transformer.create_model(params, is_train)
opt = self._create_optimizer()
model.compile(opt)
else:
model = transformer.create_model(params, is_train)
opt = self._create_optimizer()
model.compile(opt)
model.compile(opt, target_tensors=[])
model.summary()
map_data_fn = data_pipeline.map_data_for_transformer_fn
# TODO(guptapriya): Figure out a way to structure input that works in both
# distributed and non distributed cases.
train_ds = data_pipeline.train_input_fn(params)
train_ds = train_ds.map(
map_data_fn, num_parallel_calls=params["num_parallel_calls"])
if not self.distribution_strategy:
map_data_fn = data_pipeline.map_data_for_transformer_fn
train_ds = train_ds.map(
map_data_fn, num_parallel_calls=params["num_parallel_calls"])
callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.