Unverified Commit 441c9bca authored by Taylor Robie, committed by GitHub

Cleanup TPU-ization of Transformer (#4459)

* add tests for matmul embedding and schedule manager, as well as some minor cleanup

* delint

* address PR comments
parent 97760186
@@ -74,13 +74,16 @@ class EmbeddingSharedWeights(tf.layers.Layer):
       if self.method == "gather":
         embeddings = tf.gather(self.shared_weights, x)
+        embeddings *= tf.expand_dims(mask, -1)
       else:  # matmul
         embeddings = tpu_utils.embedding_matmul(
             embedding_table=self.shared_weights,
             values=tf.cast(x, dtype=tf.int32),
             mask=mask
         )
-      embeddings *= tf.expand_dims(mask, -1)
+        # embedding_matmul already zeros out masked positions, so
+        # `embeddings *= tf.expand_dims(mask, -1)` is unnecessary.

       # Scale embedding by the sqrt of the hidden size
       embeddings *= self.hidden_size ** 0.5
...
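For context, a minimal sketch of the two embedding paths this hunk touches, with illustrative shapes and values rather than the model's real configuration: the gather path looks up rows by id, so padding ids still fetch a real row and the explicit mask multiply is required; the matmul path applies the mask while building its one-hot matrix, so masked positions already come out as zero vectors. Here tf.one_hot stands in for the broadcast-equality construction that tpu_utils.embedding_matmul actually uses (sketched after the docstring hunk below).

import tensorflow as tf

table = tf.constant([[5.0, 5.0], [1.0, 1.0], [2.0, 2.0]])  # [vocab=3, hidden=2]
ids = tf.constant([[1, 2, 0]])                             # id 0 is padding
mask = tf.to_float(tf.not_equal(ids, 0))                   # [batch, length]

# gather: padding id 0 still fetches row 0 of the table, so the mask
# multiply is needed to zero it out.
gathered = tf.gather(table, ids) * tf.expand_dims(mask, -1)

# matmul: zeroing the one-hot rows first makes the product zero at padding,
# so no further mask multiply is required.
one_hot = tf.one_hot(ids, depth=3) * tf.expand_dims(mask, -1)
matmuled = tf.tensordot(one_hot, table, axes=[[2], [0]])   # equals `gathered`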
@@ -132,20 +132,3 @@ class Manager(object):
     assert self.use_tpu, "epochs_to_steps should only be reached when using TPU"
     total_num_tokens = NUM_EXAMPLES[mode] * self.max_length * num_epochs
     return total_num_tokens // self.batch_size
-
-  def _sleep_if_tpu(self):
-    """Sleep for a minute if TPUs are used.
-
-    There is currently an issue with TPUs where starting a train or evaluation
-    before all of the TPU queues have cleared causes the TPU to freeze. This
-    is a temporary workaround until the issue can be properly resolved.
-    """
-    if self.use_tpu:
-      tf.logging.info("Sleeping to allow TPU queues to clear.")
-      time.sleep(60)
-
-  def post_train(self):
-    self._sleep_if_tpu()
-
-  def post_eval(self):
-    self._sleep_if_tpu()
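The retained epochs_to_steps arithmetic above converts epochs to TPU steps by counting tokens rather than examples. A quick worked check using the eval-set figures that appear in the tests below (the 3000-example count is taken from the test comments, not from this hunk):

# Worked example of total_num_tokens // batch_size with the schedule
# tests' eval values.
num_examples = 3000   # NUM_EXAMPLES for eval, per the test comments below
max_length = 256      # maximum tokens per example
batch_size = 2048     # tokens per batch
num_epochs = 1

total_num_tokens = num_examples * max_length * num_epochs  # 768,000 tokens
print(total_num_tokens // batch_size)                      # 375 steps per eval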
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test Transformer's schedule manager."""

import tensorflow as tf

from official.transformer.utils import schedule


class ScheduleBaseTester(tf.test.TestCase):

  def test_mutual_exclusivity(self):
    with self.assertRaises(ValueError):
      schedule.Manager(
          train_steps=100, steps_between_evals=100, train_epochs=2,
          epochs_between_evals=1, default_train_epochs=None, batch_size=2048,
          max_length=256)

  def test_step_basis(self):
    manager = schedule.Manager(
        train_steps=1000, steps_between_evals=100, train_epochs=None,
        epochs_between_evals=None, default_train_epochs=None, batch_size=2048,
        max_length=256)
    self.assertEqual(manager.single_iteration_train_steps, 100)

    # Evaluation uses the full set
    self.assertIsNone(manager.single_iteration_eval_steps)
    self.assertIsNone(manager.repeat_dataset)

  def test_epoch_basis(self):
    manager = schedule.Manager(
        train_steps=None, steps_between_evals=None, train_epochs=10,
        epochs_between_evals=2, default_train_epochs=None, batch_size=2048,
        max_length=256)

    # For non-TPU runs, the estimator relies on dataset exhaustion
    self.assertIsNone(manager.single_iteration_train_steps)
    self.assertIsNone(manager.single_iteration_eval_steps)
    self.assertEqual(manager.repeat_dataset, 2)

  def test_step_basis_tpu(self):
    manager = schedule.Manager(
        train_steps=1000, steps_between_evals=100, train_epochs=None,
        epochs_between_evals=None, default_train_epochs=None, batch_size=2048,
        max_length=256, use_tpu=True)
    self.assertEqual(manager.single_iteration_train_steps, 100)

    # num_eval_examples / (batch_size / max_length) == 3000 / (2048 / 256)
    self.assertEqual(manager.single_iteration_eval_steps, 375)
    self.assertIsNone(manager.repeat_dataset)

  def test_epoch_basis_tpu(self):
    manager = schedule.Manager(
        train_steps=None, steps_between_evals=None, train_epochs=10,
        epochs_between_evals=2, default_train_epochs=None, batch_size=2048,
        max_length=256, use_tpu=True)
    self.assertEqual(
        manager.single_iteration_train_steps,
        schedule.NUM_EXAMPLES[tf.estimator.ModeKeys.TRAIN] * 2 // (2048 / 256)
    )

    # num_eval_examples / (batch_size / max_length) == 3000 / (2048 / 256)
    self.assertEqual(manager.single_iteration_eval_steps, 375)
    self.assertEqual(manager.repeat_dataset, 2)


if __name__ == "__main__":
  tf.test.main()
@@ -90,7 +90,8 @@ def embedding_matmul(embedding_table, values, mask, name='embedding_matmul'):
   The matrix to be multiplied by the embedding table Tensor is constructed
   via an implementation of scatter based on broadcasting embedding indices
   and performing an equality comparison against a broadcasted
-  range(num_embedding_table_rows).
+  range(num_embedding_table_rows). All masked positions will produce an
+  embedding vector of zeros.

   Args:
     embedding_table: Tensor of embedding table.
...
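The sentence added to the docstring follows directly from how that one-hot matrix is built. A rough sketch of the broadcast-equality construction it describes (a plausible reading of the docstring, not the library's exact code):

# Sketch, assuming 2-D int `values` [batch, length] and a float `mask` of
# the same shape; not the actual implementation in official/utils/accelerator.
import tensorflow as tf

def embedding_matmul_sketch(embedding_table, values, mask):
  vocab_size = tf.shape(embedding_table)[0]
  # [batch, length, 1] == [vocab] broadcasts to a one-hot [batch, length, vocab].
  one_hot = tf.to_float(
      tf.equal(tf.expand_dims(values, -1), tf.range(vocab_size)))
  # Masked positions become all-zero rows, so their product with the table
  # is a zero embedding vector -- the behavior the docstring now states.
  one_hot *= tf.expand_dims(mask, -1)
  # [batch, length, vocab] x [vocab, hidden] -> [batch, length, hidden]
  return tf.tensordot(one_hot, embedding_table, axes=[[2], [0]])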
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test TPU optimized matmul embedding."""

import numpy as np
import tensorflow as tf

from official.utils.accelerator import tpu as tpu_utils


TEST_CASES = [
    dict(embedding_dim=256, vocab_size=1000, sequence_length=64,
         batch_size=32, seed=54131),
    dict(embedding_dim=8, vocab_size=15, sequence_length=12,
         batch_size=256, seed=536413),
    dict(embedding_dim=2048, vocab_size=512, sequence_length=50,
         batch_size=8, seed=35124)
]


class TPUBaseTester(tf.test.TestCase):

  def construct_embedding_and_values(self, embedding_dim, vocab_size,
                                     sequence_length, batch_size, seed):
    np.random.seed(seed)

    embeddings = np.random.random(size=(vocab_size, embedding_dim))
    embedding_table = tf.convert_to_tensor(embeddings, dtype=tf.float32)

    tokens = np.random.randint(low=1, high=vocab_size-1,
                               size=(batch_size, sequence_length))
    for i in range(batch_size):
      tokens[i, np.random.randint(low=0, high=sequence_length-1):] = 0
    values = tf.convert_to_tensor(tokens, dtype=tf.int32)
    mask = tf.to_float(tf.not_equal(values, 0))
    return embedding_table, values, mask

  def _test_embedding(self, embedding_dim, vocab_size,
                      sequence_length, batch_size, seed):
    """Test that matmul embedding matches embedding lookup (gather)."""
    with self.test_session():
      embedding_table, values, mask = self.construct_embedding_and_values(
          embedding_dim=embedding_dim,
          vocab_size=vocab_size,
          sequence_length=sequence_length,
          batch_size=batch_size,
          seed=seed
      )

      embedding = (tf.nn.embedding_lookup(params=embedding_table, ids=values) *
                   tf.expand_dims(mask, -1))

      matmul_embedding = tpu_utils.embedding_matmul(
          embedding_table=embedding_table, values=values, mask=mask)

      self.assertAllClose(embedding, matmul_embedding)

  def _test_masking(self, embedding_dim, vocab_size,
                    sequence_length, batch_size, seed):
    """Test that matmul embedding properly zeros masked positions."""
    with self.test_session():
      embedding_table, values, mask = self.construct_embedding_and_values(
          embedding_dim=embedding_dim,
          vocab_size=vocab_size,
          sequence_length=sequence_length,
          batch_size=batch_size,
          seed=seed
      )

      matmul_embedding = tpu_utils.embedding_matmul(
          embedding_table=embedding_table, values=values, mask=mask)

      self.assertAllClose(matmul_embedding,
                          matmul_embedding * tf.expand_dims(mask, -1))

  def test_embedding_0(self):
    self._test_embedding(**TEST_CASES[0])

  def test_embedding_1(self):
    self._test_embedding(**TEST_CASES[1])

  def test_embedding_2(self):
    self._test_embedding(**TEST_CASES[2])

  def test_masking_0(self):
    self._test_masking(**TEST_CASES[0])

  def test_masking_1(self):
    self._test_masking(**TEST_CASES[1])

  def test_masking_2(self):
    self._test_masking(**TEST_CASES[2])


if __name__ == "__main__":
  tf.test.main()