"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "a5cfcb93ffbe52fef49ecdcfd1ce01974799694e"
Commit cb8ce606 authored by Nimit Nigania

Merge remote-tracking branch 'upstream/master'

parents 52372782 62184a96
@@ -185,6 +185,13 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.early_stopping = True
     self._run_and_report_benchmark()
 
+  def benchmark_1_gpu_ctl_run_eagerly_early_stop(self):
+    self._setup()
+    FLAGS.keras_use_ctl = True
+    FLAGS.early_stopping = True
+    FLAGS.run_eagerly = True
+    self._run_and_report_benchmark()
+
   def benchmark_xla_1_gpu_ctl_early_stop(self):
     self._setup()
     FLAGS.keras_use_ctl = True
...@@ -207,7 +214,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase): ...@@ -207,7 +214,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
self._run_and_report_benchmark() self._run_and_report_benchmark()
############################################# #############################################
# Tests below with mlperf in the test name are of two types # Tests below with mlperf in the test name are of two types:
# 1) 1 GPU tests are based on MLPerf 0.5 and the TensorFlow pulled submission. # 1) 1 GPU tests are based on MLPerf 0.5 and the TensorFlow pulled submission.
# 2) 8 GPU tests are based on MLPerf 0.5 and use NVIDIA's hyper parameters. # 2) 8 GPU tests are based on MLPerf 0.5 and use NVIDIA's hyper parameters.
# #
@@ -258,6 +265,14 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.train_epochs = 7
     self._run_and_report_benchmark_mlperf_like()
 
+  def benchmark_1_gpu_ctl_run_eagerly_mlperf_like(self):
+    """1 GPU using CTL with eager and distribution strategy."""
+    self._setup()
+    FLAGS.keras_use_ctl = True
+    FLAGS.run_eagerly = True
+    FLAGS.train_epochs = 7
+    self._run_and_report_benchmark()
+
   def benchmark_xla_1_gpu_ctl_mlperf_like(self):
     """1 GPU using CTL with XLA."""
     self._setup()
...
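Note: each benchmark method above follows the same pattern: reset state with `_setup()`, flip a few absl flags, then run. A minimal sketch of how such flags are defined (the flag names are assumed to mirror the ones used above; these definitions are not part of this commit):

    from absl import flags

    FLAGS = flags.FLAGS
    flags.DEFINE_bool("keras_use_ctl", False,
                      "Use a custom training loop (CTL) instead of model.fit.")
    flags.DEFINE_bool("run_eagerly", False,
                      "Run ops eagerly instead of tracing into a graph.")
    flags.DEFINE_bool("early_stopping", False,
                      "Stop training once the target metric is reached.")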
@@ -285,7 +285,6 @@ def run_ncf(_):
   train_input_iterator = strategy.make_dataset_iterator(train_input_dataset)
   eval_input_iterator = strategy.make_dataset_iterator(eval_input_dataset)
 
-  @tf.function
   def train_step():
     """Called once per step to train the model."""
     def step_fn(features):
@@ -310,7 +309,6 @@ def run_ncf(_):
           tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
     return mean_loss
 
-  @tf.function
   def eval_step():
     """Called once per eval step to compute eval metrics."""
     def step_fn(features):
@@ -330,6 +328,10 @@ def run_ncf(_):
           tf.distribute.ReduceOp.SUM, per_replica_hr_count, axis=None)
     return hr_sum, hr_count
 
+  if not FLAGS.run_eagerly:
+    train_step = tf.function(train_step)
+    eval_step = tf.function(eval_step)
+
   time_callback.on_train_begin()
   for epoch in range(FLAGS.train_epochs):
     for cb in callbacks:
...
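The hunks above replace unconditional `@tf.function` decorators with conditional wrapping, which is what makes the new `run_eagerly` benchmarks possible: `tf.function(f)` is equivalent to decorating `f`, so the graph path is unchanged when the flag is off. A self-contained sketch of the pattern (toy step function; the flag value is a stand-in):

    import tensorflow as tf

    run_eagerly = False  # stand-in for FLAGS.run_eagerly

    def train_step():
      # A real step would run the model under a distribution strategy
      # and return the reduced loss.
      return tf.constant(0.0)

    if not run_eagerly:
      # Equivalent to decorating train_step with @tf.function.
      train_step = tf.function(train_step)

    loss = train_step()  # same call site on both paths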
@@ -18,11 +18,31 @@ Source implementation from Tensor2Tensor:
   https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/beam_search.py
 """
 
+import numpy as np
 import tensorflow as tf
 
 from tensorflow.python.util import nest
 
-# Default value for INF
-INF = 1. * 1e7
+
+def inf(dtype):
+  """Returns a value close to infinity, but is still finite in `dtype`.
+
+  This is useful to get a very large value that is still zero when multiplied by
+  zero. The floating-point "Inf" value is NaN when multiplied by zero.
+
+  Args:
+    dtype: A dtype. The returned value will be finite when casted to this dtype.
+
+  Returns:
+    A very large value.
+  """
+  if dtype == "float32":
+    return 1e7
+  elif dtype == "float16":
+    # Disable no-member lint error, as the linter thinks np.float16 does not
+    # exist for some reason.
+    return np.finfo(np.float16).max  # pylint: disable=no-member
+  else:
+    raise AssertionError('Invalid dtype: %s' % dtype)
 
 
 class _StateKeys(object):
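The reason `inf()` returns a large finite value rather than `float("inf")` is stated in its docstring: beam search multiplies scores by 0/1 masks, and a true infinity times zero is NaN. A quick demonstration:

    import numpy as np

    print(float("inf") * 0.0)        # nan -- would poison masked scores
    print(np.finfo(np.float16).max)  # 65504.0, the float16 stand-in
    print(np.float16(65504) * 0.0)   # 0.0 -- safe under a zero mask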
@@ -60,7 +80,7 @@ class SequenceBeamSearch(object):
   """Implementation of beam search loop."""
 
   def __init__(self, symbols_to_logits_fn, vocab_size, batch_size,
-               beam_size, alpha, max_decode_length, eos_id):
+               beam_size, alpha, max_decode_length, eos_id, dtype=tf.float32):
     self.symbols_to_logits_fn = symbols_to_logits_fn
     self.vocab_size = vocab_size
     self.batch_size = batch_size
@@ -68,6 +88,7 @@ class SequenceBeamSearch(object):
     self.alpha = alpha
     self.max_decode_length = max_decode_length
     self.eos_id = eos_id
+    self.dtype = tf.as_dtype(dtype)
 
   def search(self, initial_ids, initial_cache):
     """Beam search for sequences with highest scores."""
@@ -105,6 +126,14 @@ class SequenceBeamSearch(object):
     Returns:
       state and shape invariant dictionaries with keys from _StateKeys
     """
+    for key, value in initial_cache.items():
+      for inner_value in nest.flatten(value):
+        if inner_value.dtype != self.dtype:
+          raise TypeError(
+              "initial_cache element for key '%s' has dtype %s that does not "
+              "match SequenceBeamSearch's dtype of %s. Value: %s" %
+              (key, value.dtype.name, self.dtype.name, inner_value))
+
     # Current loop index (starts at 0)
     cur_index = tf.constant(0)
@@ -115,7 +144,7 @@ class SequenceBeamSearch(object):
     # Create tensor for storing initial log probabilities.
     # Assume initial_ids are prob 1.0
     initial_log_probs = tf.constant(
-        [[0.] + [-float("inf")] * (self.beam_size - 1)])
+        [[0.] + [-float("inf")] * (self.beam_size - 1)], dtype=self.dtype)
     alive_log_probs = tf.tile(initial_log_probs, [self.batch_size, 1])
 
     # Expand all values stored in the dictionary to the beam size, so that each
@@ -127,7 +156,8 @@ class SequenceBeamSearch(object):
     finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)
 
     # Set scores of the initial finished seqs to negative infinity.
-    finished_scores = tf.ones([self.batch_size, self.beam_size]) * -INF
+    finished_scores = tf.ones([self.batch_size, self.beam_size],
+                              dtype=self.dtype) * -inf(self.dtype)
 
     # Initialize finished flags with all False values.
     finished_flags = tf.zeros([self.batch_size, self.beam_size], tf.bool)
@@ -185,20 +215,22 @@ class SequenceBeamSearch(object):
     not_at_max_decode_length = tf.less(i, self.max_decode_length)
 
     # Calculate largest length penalty (the larger penalty, the better score).
-    max_length_norm = _length_normalization(self.alpha, self.max_decode_length)
+    max_length_norm = _length_normalization(self.alpha, self.max_decode_length,
+                                            dtype=self.dtype)
     # Get the best possible scores from alive sequences.
     best_alive_scores = alive_log_probs[:, 0] / max_length_norm
 
     # Compute worst score in finished sequences for each batch element
     finished_scores *= tf.cast(finished_flags,
-                               tf.float32)  # set filler scores to zero
+                               self.dtype)  # set filler scores to zero
     lowest_finished_scores = tf.reduce_min(finished_scores, axis=1)
 
     # If there are no finished sequences in a batch element, then set the lowest
     # finished score to -INF for that element.
     finished_batches = tf.reduce_any(finished_flags, 1)
-    lowest_finished_scores += (1.0 -
-                               tf.cast(finished_batches, tf.float32)) * -INF
+    lowest_finished_scores += ((1.0 -
+                                tf.cast(finished_batches, self.dtype)) *
+                               -inf(self.dtype))
 
     worst_finished_score_better_than_best_alive_score = tf.reduce_all(
         tf.greater(lowest_finished_scores, best_alive_scores)
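For intuition, the test above lets decoding stop early once, in every batch element, even the worst finished score beats the best score any still-alive beam could achieve. A small numeric sketch (made-up scores):

    import tensorflow as tf

    lowest_finished_scores = tf.constant([-1.2, -0.5])  # worst finished, per batch
    best_alive_scores = tf.constant([-2.0, -0.9])       # best possible alive
    stop = tf.reduce_all(tf.greater(lowest_finished_scores, best_alive_scores))
    # stop == True: no alive beam can overtake a finished one in either element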
@@ -319,9 +351,9 @@ class SequenceBeamSearch(object):
         Log probabilities of top alive sequences
         Dict cache storing decoder states for top alive sequences}
     """
-    # To prevent finished sequences from being considered, set log probs to -INF
+    # To prevent finished sequences from being considered, set log probs to -inf
     new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
-    new_log_probs += tf.cast(new_finished_flags, tf.float32) * -INF
+    new_log_probs += tf.cast(new_finished_flags, self.dtype) * -inf(self.dtype)
 
     top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams(
         [new_seq, new_log_probs, new_cache], new_log_probs, self.batch_size,
@@ -361,12 +393,13 @@ class SequenceBeamSearch(object):
          tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2)
 
     # Calculate new seq scores from log probabilities.
-    length_norm = _length_normalization(self.alpha, i + 1)
+    length_norm = _length_normalization(self.alpha, i + 1, dtype=self.dtype)
     new_scores = new_log_probs / length_norm
 
     # Set the scores of the still-alive seq in new_seq to large negative values.
     new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
-    new_scores += (1. - tf.cast(new_finished_flags, tf.float32)) * -INF
+    new_scores += ((1. - tf.cast(new_finished_flags, self.dtype)) *
+                   -inf(self.dtype))
 
     # Combine sequences, scores, and flags.
     finished_seq = tf.concat([finished_seq, new_seq], axis=1)
@@ -422,9 +455,9 @@ def _log_prob_from_logits(logits):
   return logits - tf.reduce_logsumexp(logits, axis=2, keepdims=True)
 
 
-def _length_normalization(alpha, length):
+def _length_normalization(alpha, length, dtype=tf.float32):
   """Return length normalization factor."""
-  return tf.pow(((5. + tf.cast(length, tf.float32)) / 6.), alpha)
+  return tf.pow(((5. + tf.cast(length, dtype)) / 6.), alpha)
 
 
 def _expand_to_beam_size(tensor, beam_size):
...
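`_length_normalization` implements the GNMT-style length penalty ((5 + length) / 6) ** alpha; dividing log probabilities by it keeps long hypotheses competitive with short ones. A quick check of its growth:

    alpha = 0.6
    for length in (1, 10, 50):
      print(length, ((5.0 + length) / 6.0) ** alpha)
    # 1 -> 1.00, 10 -> ~1.73, 50 -> ~3.78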
@@ -57,7 +57,7 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
 def sequence_beam_search(
     symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
-    alpha, max_decode_length, eos_id):
+    alpha, max_decode_length, eos_id, dtype="float32"):
   """Search for sequence of subtoken ids with the largest probability.
 
   Args:
@@ -76,7 +76,8 @@ def sequence_beam_search(
     beam_size: int number of beams
     alpha: float defining the strength of length normalization
     max_decode_length: maximum length to decoded sequence
-    eos_id: int id of eos token, used to determine when a sequence has finished
+    eos_id: int id of eos token, used to determine when a sequence has finished,
+    dtype: The dtype to use.
 
   Returns:
     Top decoded sequences [batch_size, beam_size, max_decode_length]
@@ -85,10 +86,12 @@ def sequence_beam_search(
   batch_size = tf.shape(initial_ids)[0]
   if misc.is_v2():
     sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
-                               beam_size, alpha, max_decode_length, eos_id)
+                               beam_size, alpha, max_decode_length, eos_id,
+                               dtype)
   else:
     sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
-                                beam_size, alpha, max_decode_length, eos_id)
+                                beam_size, alpha, max_decode_length, eos_id,
+                                dtype)
   return sbs.search(initial_ids, initial_cache)
...
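A toy invocation of the updated wrapper, showing how the new `dtype` argument threads through to the search object. The trivial `symbols_to_logits_fn` and all argument values here are illustrative, not from the commit:

    import tensorflow as tf

    vocab_size, beam_size, batch_size = 8, 2, 1

    def symbols_to_logits_fn(ids, i, cache):
      # A real model would run its decoder here; uniform logits suffice
      # to exercise the search loop.
      return tf.zeros([tf.shape(ids)[0], vocab_size], tf.float16), cache

    decoded_ids, scores = sequence_beam_search(
        symbols_to_logits_fn, initial_ids=tf.zeros([batch_size], tf.int32),
        initial_cache={}, vocab_size=vocab_size, beam_size=beam_size,
        alpha=0.6, max_decode_length=5, eos_id=1, dtype="float16")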
@@ -24,14 +24,24 @@ import tensorflow as tf
 class EmbeddingSharedWeights(tf.keras.layers.Layer):
   """Calculates input embeddings and pre-softmax linear with shared weights."""
 
-  def __init__(self, vocab_size, hidden_size):
+  def __init__(self, vocab_size, hidden_size, dtype=None):
     """Specify characteristic parameters of embedding layer.
 
     Args:
       vocab_size: Number of tokens in the embedding. (Typically ~32,000)
       hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
+      dtype: The dtype of the layer: float16 or float32.
     """
-    super(EmbeddingSharedWeights, self).__init__()
+    if dtype == tf.float16:
+      # We cannot rely on the global policy of "infer_with_float32_vars", as
+      # this layer is called on both int64 inputs and floating-point inputs.
+      # If "infer_with_float32_vars" is used, the dtype will be inferred to be
+      # int64, which means floating-point inputs would not be casted.
+      # TODO(b/138859351): Remove this logic once we stop using the deprecated
+      # "infer_with_float32_vars" policy
+      dtype = tf.keras.mixed_precision.experimental.Policy(
+          "float16_with_float32_vars")
+    super(EmbeddingSharedWeights, self).__init__(dtype=dtype)
     self.vocab_size = vocab_size
     self.hidden_size = hidden_size
@@ -78,8 +88,8 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
     """Applies embedding based on inputs tensor."""
     with tf.name_scope("embedding"):
       # Create binary mask of size [batch_size, length]
-      mask = tf.cast(tf.not_equal(inputs, 0), tf.float32)
       embeddings = tf.gather(self.shared_weights, inputs)
+      mask = tf.cast(tf.not_equal(inputs, 0), embeddings.dtype)
       embeddings *= tf.expand_dims(mask, -1)
       # Scale embedding by the sqrt of the hidden size
       embeddings *= self.hidden_size ** 0.5
...
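The reordering in the second hunk lets the padding mask pick up its dtype from the gathered embeddings, so the same code works whether the layer runs in float32 or under the float16 policy. The pattern in isolation (toy sizes):

    import tensorflow as tf

    shared_weights = tf.random.normal([16, 4], dtype=tf.float16)  # [vocab, hidden]
    inputs = tf.constant([[3, 0, 5]])                 # id 0 is padding
    embeddings = tf.gather(shared_weights, inputs)    # [1, 3, 4], float16
    mask = tf.cast(tf.not_equal(inputs, 0), embeddings.dtype)
    embeddings *= tf.expand_dims(mask, -1)            # zero out padding rows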
@@ -32,6 +32,11 @@ from official.transformer.v2 import ffn_layer
 from official.transformer.v2 import metrics
 
+# Disable the not-callable lint error, since it claims many objects are not
+# callable when they actually are.
+# pylint: disable=not-callable
+
 
 def create_model(params, is_train):
   """Creates transformer model."""
   with tf.name_scope("model"):
@@ -80,7 +85,7 @@ class Transformer(tf.keras.Model):
     super(Transformer, self).__init__(name=name)
     self.params = params
     self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
-        params["vocab_size"], params["hidden_size"])
+        params["vocab_size"], params["hidden_size"], dtype=params["dtype"])
     self.encoder_stack = EncoderStack(params)
     self.decoder_stack = DecoderStack(params)
@@ -216,8 +221,9 @@ class Transformer(tf.keras.Model):
     timing_signal = model_utils.get_position_encoding(
         max_decode_length + 1, self.params["hidden_size"])
+    timing_signal = tf.cast(timing_signal, self.params["dtype"])
     decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
-        max_decode_length)
+        max_decode_length, dtype=self.params["dtype"])
 
     def symbols_to_logits_fn(ids, i, cache):
       """Generate logits for next potential IDs.
@@ -257,12 +263,11 @@ class Transformer(tf.keras.Model):
   def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
     """Return predicted sequence."""
-    # Currently, we always do prediction in float32.
-    # TODO(reedwm): Add float16 support.
-    encoder_outputs = tf.cast(encoder_outputs, tf.float32)
     batch_size = tf.shape(encoder_outputs)[0]
     input_length = tf.shape(encoder_outputs)[1]
     max_decode_length = input_length + self.params["extra_decode_length"]
+    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
+                                             self.params["dtype"])
 
     symbols_to_logits_fn = self._get_symbols_to_logits_fn(
         max_decode_length, training)
@@ -274,8 +279,10 @@ class Transformer(tf.keras.Model):
     # pylint: disable=g-complex-comprehension
     cache = {
         "layer_%d" % layer: {
-            "k": tf.zeros([batch_size, 0, self.params["hidden_size"]]),
-            "v": tf.zeros([batch_size, 0, self.params["hidden_size"]])
+            "k": tf.zeros([batch_size, 0, self.params["hidden_size"]],
+                          dtype=self.params["dtype"]),
+            "v": tf.zeros([batch_size, 0, self.params["hidden_size"]],
+                          dtype=self.params["dtype"])
         } for layer in range(self.params["num_hidden_layers"])
     }
     # pylint: enable=g-complex-comprehension
@@ -293,7 +300,8 @@ class Transformer(tf.keras.Model):
         beam_size=self.params["beam_size"],
         alpha=self.params["alpha"],
         max_decode_length=max_decode_length,
-        eos_id=EOS_ID)
+        eos_id=EOS_ID,
+        dtype=self.params["dtype"])
 
     # Get the top sequence for each batch element
     top_decoded_ids = decoded_ids[:, 0, 1:]
...
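The casts added in `predict` matter because TensorFlow does not auto-promote floating-point dtypes: combining a float16 tensor with a float32 one raises an error rather than upcasting. A minimal illustration of why the attention bias and decode cache must match `params["dtype"]`:

    import tensorflow as tf

    logits = tf.zeros([2, 4], tf.float16)
    bias = tf.zeros([2, 4], tf.float32)
    # logits + bias  # raises InvalidArgumentError: incompatible dtypes
    out = logits + tf.cast(bias, logits.dtype)  # cast first, as predict() now does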
@@ -95,6 +95,12 @@ class TransformerTaskTest(tf.test.TestCase):
     t = tm.TransformerTask(FLAGS)
     t.train()
 
+  @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
+  def test_train_fp16(self):
+    FLAGS.dtype = 'fp16'
+    t = tm.TransformerTask(FLAGS)
+    t.train()
+
   @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
   def test_train_2_gpu(self):
     if context.num_gpus() < 2:
...