Unverified Commit 441c9bca authored by Taylor Robie, committed by GitHub

Cleanup TPU-ization of Transformer (#4459)

* add tests for matmul embedding and schedule manager, as well as some minor cleanup

* delint

* address PR comments
parent 97760186
@@ -74,13 +74,16 @@ class EmbeddingSharedWeights(tf.layers.Layer):
      if self.method == "gather":
        embeddings = tf.gather(self.shared_weights, x)
        embeddings *= tf.expand_dims(mask, -1)
      else:  # matmul
        embeddings = tpu_utils.embedding_matmul(
            embedding_table=self.shared_weights,
            values=tf.cast(x, dtype=tf.int32),
            mask=mask
        )
        # embedding_matmul already zeros out masked positions, so a second
        # `embeddings *= tf.expand_dims(mask, -1)` is unnecessary here.

      # Scale embedding by the sqrt of the hidden size
      embeddings *= self.hidden_size ** 0.5
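The two branches are intended to be interchangeable: gather-plus-mask and embedding_matmul should yield identical tensors, with padded positions zeroed in both. A minimal equivalence sketch (toy values chosen for illustration; the tpu_utils import path is the one used by the tests below):

    import tensorflow as tf
    from official.utils.accelerator import tpu as tpu_utils

    table = tf.constant([[0., 0.], [1., 1.], [2., 2.]])  # vocab_size=3, dim=2
    ids = tf.constant([[1, 2, 0]])                        # trailing 0 is padding
    mask = tf.to_float(tf.not_equal(ids, 0))

    gathered = tf.gather(table, ids) * tf.expand_dims(mask, -1)
    matmuled = tpu_utils.embedding_matmul(
        embedding_table=table, values=ids, mask=mask)

    with tf.Session() as sess:
      g, m = sess.run([gathered, matmuled])  # g == m; the padded row is zeros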
@@ -132,20 +132,3 @@ class Manager(object):
    assert self.use_tpu, "epochs_to_steps should only be reached when using TPU"
    total_num_tokens = NUM_EXAMPLES[mode] * self.max_length * num_epochs
    return total_num_tokens // self.batch_size
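Because batch_size here is measured in tokens rather than examples, converting epochs to steps is simply total tokens divided by tokens per step. As a worked example with the values used in the tests below (3000 eval examples, max_length=256, batch_size=2048):

    total_num_tokens = 3000 * 256 * 1  # one eval epoch is 768,000 tokens
    total_num_tokens // 2048           # = 375 steps per evaluation pass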
  def _sleep_if_tpu(self):
    """Sleep for a minute if TPUs are used.

    There is currently an issue with TPUs where starting a train or evaluation
    before all of the TPU queues have cleared causes the TPU to freeze. This
    is a temporary workaround until the issue can be properly resolved.
    """
    if self.use_tpu:
      tf.logging.info("Sleeping to allow TPU queues to clear.")
      time.sleep(60)

  def post_train(self):
    self._sleep_if_tpu()

  def post_eval(self):
    self._sleep_if_tpu()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test Transformer's schedule manager."""
import tensorflow as tf
from official.transformer.utils import schedule
class ScheduleBaseTester(tf.test.TestCase):
  def test_mutual_exclusivity(self):
    with self.assertRaises(ValueError):
      schedule.Manager(
          train_steps=100, steps_between_evals=100, train_epochs=2,
          epochs_between_evals=1, default_train_epochs=None, batch_size=2048,
          max_length=256)

  def test_step_basis(self):
    manager = schedule.Manager(
        train_steps=1000, steps_between_evals=100, train_epochs=None,
        epochs_between_evals=None, default_train_epochs=None, batch_size=2048,
        max_length=256)
    self.assertEqual(manager.single_iteration_train_steps, 100)

    # Evaluation uses the full set
    self.assertIsNone(manager.single_iteration_eval_steps)
    self.assertIsNone(manager.repeat_dataset)

  def test_epoch_basis(self):
    manager = schedule.Manager(
        train_steps=None, steps_between_evals=None, train_epochs=10,
        epochs_between_evals=2, default_train_epochs=None, batch_size=2048,
        max_length=256)

    # For non-TPU runs, the Estimator relies on dataset exhaustion
    self.assertIsNone(manager.single_iteration_train_steps)
    self.assertIsNone(manager.single_iteration_eval_steps)
    self.assertEqual(manager.repeat_dataset, 2)

  def test_step_basis_tpu(self):
    manager = schedule.Manager(
        train_steps=1000, steps_between_evals=100, train_epochs=None,
        epochs_between_evals=None, default_train_epochs=None, batch_size=2048,
        max_length=256, use_tpu=True)
    self.assertEqual(manager.single_iteration_train_steps, 100)

    # num_eval_examples / (batch_size / max_length) == 3000 / (2048 / 256)
    self.assertEqual(manager.single_iteration_eval_steps, 375)
    self.assertIsNone(manager.repeat_dataset)

  def test_epoch_basis_tpu(self):
    manager = schedule.Manager(
        train_steps=None, steps_between_evals=None, train_epochs=10,
        epochs_between_evals=2, default_train_epochs=None, batch_size=2048,
        max_length=256, use_tpu=True)
    self.assertEqual(
        manager.single_iteration_train_steps,
        schedule.NUM_EXAMPLES[tf.estimator.ModeKeys.TRAIN] * 2 // (2048 / 256)
    )

    # num_eval_examples / (batch_size / max_length) == 3000 / (2048 / 256)
    self.assertEqual(manager.single_iteration_eval_steps, 375)
    self.assertEqual(manager.repeat_dataset, 2)


if __name__ == "__main__":
  tf.test.main()
@@ -90,7 +90,8 @@ def embedding_matmul(embedding_table, values, mask, name='embedding_matmul'):
  The matrix to be multiplied by the embedding table Tensor is constructed
  via an implementation of scatter based on broadcasting embedding indices
  and performing an equality comparison against a broadcasted
  range(num_embedding_table_rows). All masked positions will produce an
  embedding vector of zeros.

  Args:
    embedding_table: Tensor of embedding table.
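The broadcast-and-compare scatter described in the docstring can be sketched in standalone NumPy (a minimal illustration; the function and variable names here are not the library's):

    import numpy as np

    def one_hot_embedding_matmul(embedding_table, values, mask):
      # Broadcasting ids against range(vocab_size) and comparing for equality
      # builds a one-hot matrix of shape [batch, length, vocab_size].
      vocab_size = embedding_table.shape[0]
      one_hot = (values[..., np.newaxis] == np.arange(vocab_size))
      one_hot = one_hot.astype(embedding_table.dtype)
      # Zeroing the one-hot rows of masked positions makes their embeddings
      # all zeros, so no separate masking step is needed afterwards.
      one_hot *= mask[..., np.newaxis]
      # [batch, length, vocab_size] @ [vocab_size, dim] -> [batch, length, dim]
      return one_hot @ embedding_table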
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test TPU optimized matmul embedding."""
import numpy as np
import tensorflow as tf
from official.utils.accelerator import tpu as tpu_utils
TEST_CASES = [
    dict(embedding_dim=256, vocab_size=1000, sequence_length=64,
         batch_size=32, seed=54131),
    dict(embedding_dim=8, vocab_size=15, sequence_length=12,
         batch_size=256, seed=536413),
    dict(embedding_dim=2048, vocab_size=512, sequence_length=50,
         batch_size=8, seed=35124),
]


class TPUBaseTester(tf.test.TestCase):
  def construct_embedding_and_values(self, embedding_dim, vocab_size,
                                     sequence_length, batch_size, seed):
    np.random.seed(seed)
    embeddings = np.random.random(size=(vocab_size, embedding_dim))
    embedding_table = tf.convert_to_tensor(embeddings, dtype=tf.float32)

    tokens = np.random.randint(low=1, high=vocab_size-1,
                               size=(batch_size, sequence_length))
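    # Zero out a random-length suffix of each row so that every example ends
    # with padding (token id 0) and the mask is genuinely exercised.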
    for i in range(batch_size):
      tokens[i, np.random.randint(low=0, high=sequence_length-1):] = 0
    values = tf.convert_to_tensor(tokens, dtype=tf.int32)
    mask = tf.to_float(tf.not_equal(values, 0))
    return embedding_table, values, mask
  def _test_embedding(self, embedding_dim, vocab_size,
                      sequence_length, batch_size, seed):
    """Test that matmul embedding matches embedding lookup (gather)."""
    with self.test_session():
      embedding_table, values, mask = self.construct_embedding_and_values(
          embedding_dim=embedding_dim,
          vocab_size=vocab_size,
          sequence_length=sequence_length,
          batch_size=batch_size,
          seed=seed
      )
      embedding = (tf.nn.embedding_lookup(params=embedding_table, ids=values) *
                   tf.expand_dims(mask, -1))
      matmul_embedding = tpu_utils.embedding_matmul(
          embedding_table=embedding_table, values=values, mask=mask)
      self.assertAllClose(embedding, matmul_embedding)

  def _test_masking(self, embedding_dim, vocab_size,
                    sequence_length, batch_size, seed):
    """Test that matmul embedding properly zeros masked positions."""
    with self.test_session():
      embedding_table, values, mask = self.construct_embedding_and_values(
          embedding_dim=embedding_dim,
          vocab_size=vocab_size,
          sequence_length=sequence_length,
          batch_size=batch_size,
          seed=seed
      )
      matmul_embedding = tpu_utils.embedding_matmul(
          embedding_table=embedding_table, values=values, mask=mask)
      self.assertAllClose(matmul_embedding,
                          matmul_embedding * tf.expand_dims(mask, -1))

  def test_embedding_0(self):
    self._test_embedding(**TEST_CASES[0])

  def test_embedding_1(self):
    self._test_embedding(**TEST_CASES[1])

  def test_embedding_2(self):
    self._test_embedding(**TEST_CASES[2])

  def test_masking_0(self):
    self._test_masking(**TEST_CASES[0])

  def test_masking_1(self):
    self._test_masking(**TEST_CASES[1])

  def test_masking_2(self):
    self._test_masking(**TEST_CASES[2])


if __name__ == "__main__":
  tf.test.main()