Unverified commit 1e2ceffd, authored by Ayushman Kumar and committed by GitHub

Merge pull request #4 from tensorflow/master

Updating 
parents 51e60bab c7adbbe4
@@ -27,8 +27,8 @@ from official.nlp.modeling.layers import masked_softmax
 @tf.keras.utils.register_keras_serializable(package="Text")
-class Attention(tf.keras.layers.Layer):
-  """Attention layer.
+class MultiHeadAttention(tf.keras.layers.Layer):
+  """MultiHeadAttention layer.
 
   This is an implementation of multi-headed attention based on "Attention
   Is All You Need". If `from_tensor` and `to_tensor` are the same, then
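
For orientation, here is a minimal usage sketch of the renamed layer, assembled entirely from the test cases further down in this diff (shapes and argument values mirror `test_non_masked_attention`; nothing here goes beyond what those tests show):

```python
import tensorflow as tf
from official.nlp.modeling.layers import attention

# 12 attention heads of 64 units each, attending from a 40-token sequence
# to a 20-token sequence; the batch dimension is implicit in tf.keras.Input.
layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
from_tensor = tf.keras.Input(shape=(40, 80))
to_tensor = tf.keras.Input(shape=(20, 80))
output = layer([from_tensor, to_tensor])
# For self-attention, pass the same tensor twice: layer([x, x]).
```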
@@ -70,7 +70,7 @@ class Attention(tf.keras.layers.Layer):
                kernel_constraint=None,
                bias_constraint=None,
                **kwargs):
-    super(Attention, self).__init__(**kwargs)
+    super(MultiHeadAttention, self).__init__(**kwargs)
     self._num_heads = num_heads
     self._head_size = head_size
     self._dropout_rate = dropout_rate
@@ -141,7 +141,7 @@ class Attention(tf.keras.layers.Layer):
         "bias_constraint":
             tf.keras.constraints.serialize(self._bias_constraint)
     }
-    base_config = super(Attention, self).get_config()
+    base_config = super(MultiHeadAttention, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
   def call(self, inputs):
@@ -183,7 +183,7 @@ class Attention(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package="Text")
-class CachedAttention(Attention):
+class CachedAttention(MultiHeadAttention):
   """Attention layer with cache used for auto-regressive decoding.
 
   Arguments:
......
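
The cache's layout is not visible in this excerpt, so the sketch below is an assumption for illustration only: a dict of `"key"`/`"value"` tensors pre-allocated to the full decode length, which is the usual shape of a decoding cache. What the diff does confirm is that the constructor arguments are inherited from `MultiHeadAttention` via `**kwargs`:

```python
import tensorflow as tf
from official.nlp.modeling.layers import attention

# CachedAttention takes the same constructor arguments as MultiHeadAttention.
layer = attention.CachedAttention(num_heads=2, head_size=4)

# Assumed cache layout (not shown in this excerpt): "key"/"value" tensors
# pre-allocated to the maximum decode length, so each auto-regressive step
# fills in one position instead of recomputing keys/values for the prefix.
batch_size, max_decode_len, num_heads, head_size = 2, 8, 2, 4
cache = {
    "key": tf.zeros([batch_size, max_decode_len, num_heads, head_size]),
    "value": tf.zeros([batch_size, max_decode_len, num_heads, head_size]),
}
```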
@@ -28,11 +28,11 @@ from official.nlp.modeling.layers import attention
 
 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
 # guarantees forward compatibility of this code for the V2 switchover.
 @keras_parameterized.run_all_keras_modes
-class AttentionLayerTest(keras_parameterized.TestCase):
+class MultiHeadAttentionTest(keras_parameterized.TestCase):
 
   def test_non_masked_attention(self):
     """Test that the attention layer can be created without a mask tensor."""
-    test_layer = attention.Attention(num_heads=12, head_size=64)
+    test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(40, 80))
     to_tensor = tf.keras.Input(shape=(20, 80))
@@ -41,7 +41,7 @@ class AttentionLayerTest(keras_parameterized.TestCase):
 
   def test_non_masked_self_attention(self):
     """Test with one input (self-attention) and no mask tensor."""
-    test_layer = attention.Attention(num_heads=12, head_size=64)
+    test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(40, 80))
     output = test_layer([from_tensor, from_tensor])
@@ -49,7 +49,7 @@ class AttentionLayerTest(keras_parameterized.TestCase):
 
   def test_masked_attention(self):
     """Test with a mask tensor."""
-    test_layer = attention.Attention(num_heads=2, head_size=2)
+    test_layer = attention.MultiHeadAttention(num_heads=2, head_size=2)
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(4, 8))
     to_tensor = tf.keras.Input(shape=(2, 8))
@@ -78,7 +78,7 @@ class AttentionLayerTest(keras_parameterized.TestCase):
 
   def test_initializer(self):
     """Test with a specified initializer."""
-    test_layer = attention.Attention(
+    test_layer = attention.MultiHeadAttention(
         num_heads=12,
         head_size=64,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
......
@@ -99,7 +99,7 @@ class Transformer(tf.keras.layers.Layer):
                        "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
-    self._attention_layer = attention.Attention(
+    self._attention_layer = attention.MultiHeadAttention(
         num_heads=self._num_heads,
         head_size=self._attention_head_size,
         dropout_rate=self._attention_dropout_rate,
......
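
The head-size computation above is plain integer division once divisibility has been checked; a worked example with BERT-base sized numbers:

```python
# 768 hidden units split across 12 heads -> 64 units per head; the guard
# above raises a ValueError when hidden_size is not a multiple of num_heads.
hidden_size, num_heads = 768, 12
assert hidden_size % num_heads == 0
attention_head_size = hidden_size // num_heads  # == 64
```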
@@ -59,7 +59,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
                num_attention_heads,
                intermediate_size,
                intermediate_activation,
-               attention_cls=attention.Attention,
+               attention_cls=attention.MultiHeadAttention,
                attention_cfg=None,
                dropout_rate=0.0,
                attention_dropout_rate=0.0,
......
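
Because `attention_cls` is a class rather than an instance, callers can substitute any subclass of the renamed layer — exactly what the `ValidatedAttentionLayer` test in the next file does. A hedged sketch, assuming `attention_cfg` is the keyword-argument dict used to instantiate `attention_cls` (other scaffold arguments beyond those excerpted above are left at their defaults):

```python
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import transformer_scaffold

class MyAttention(attention.MultiHeadAttention):
  """A stand-in subclass; a real one would override call() or build()."""

# Assumed wiring: attention_cfg is forwarded as kwargs to attention_cls.
test_layer = transformer_scaffold.TransformerScaffold(
    num_attention_heads=2,
    intermediate_size=128,
    intermediate_activation='relu',
    attention_cls=MyAttention,
    attention_cfg={'num_heads': 2, 'head_size': 8})
```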
@@ -33,7 +33,7 @@ from official.nlp.modeling.layers import transformer_scaffold
 # boolean 'True'. We register this class as a Keras serializable so we can
 # test serialization below.
 @tf.keras.utils.register_keras_serializable(package='TestOnly')
-class ValidatedAttentionLayer(attention.Attention):
+class ValidatedAttentionLayer(attention.MultiHeadAttention):
 
   def __init__(self, call_list, **kwargs):
     super(ValidatedAttentionLayer, self).__init__(**kwargs)
......
@@ -186,5 +186,4 @@ class TransformerLayerTest(keras_parameterized.TestCase):
 
 if __name__ == '__main__':
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -171,5 +171,4 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
 
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -626,5 +626,4 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
 
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -200,5 +200,4 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
 
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -26,12 +26,11 @@ import re
 import sys
 import unicodedata
 
 # pylint: disable=g-bad-import-order
-import six
 from absl import app as absl_app
 from absl import flags
+import six
 from six.moves import range
 import tensorflow as tf
 # pylint: enable=g-bad-import-order
 
 from official.nlp.transformer.utils import metrics
 from official.nlp.transformer.utils import tokenizer
......
@@ -205,6 +205,12 @@ def define_transformer_flags():
           'use padded_decode, it has not been tested. In addition, this method '
           'will introduce unnecessary overheads which grow quadratically with '
           'the max sequence length.'))
+  flags.DEFINE_bool(
+      name='enable_checkpointing',
+      default=True,
+      help=flags_core.help_wrap(
+          'Whether to do checkpointing during training. When running under '
+          'benchmark harness, we will avoid checkpointing.'))
 
   flags_core.set_defaults(data_dir='/tmp/translate_ende',
                           model_dir='/tmp/transformer_model',
......
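
Since `enable_checkpointing` is an absl boolean flag, a benchmark harness can disable it from the command line with `--noenable_checkpointing` (absl's standard negated form) or `--enable_checkpointing=false`; the default of `True` leaves ordinary training runs unchanged.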
@@ -59,5 +59,4 @@ class ModelUtilsTest(tf.test.TestCase):
 
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -94,5 +94,4 @@ class TransformerLayersTest(tf.test.TestCase):
 
 if __name__ == "__main__":
   tf.compat.v1.enable_v2_behavior()
   tf.test.main()
-
@@ -159,6 +159,7 @@ class TransformerTask(object):
     params["enable_tensorboard"] = flags_obj.enable_tensorboard
     params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
     params["steps_between_evals"] = flags_obj.steps_between_evals
+    params["enable_checkpointing"] = flags_obj.enable_checkpointing
 
     self.distribution_strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=flags_obj.distribution_strategy,
@@ -313,10 +314,12 @@ class TransformerTask(object):
             tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
                                         current_step)
-        checkpoint_name = checkpoint.save(
-            os.path.join(flags_obj.model_dir,
-                         "ctl_step_{}.ckpt".format(current_step)))
-        logging.info("Saved checkpoint to %s", checkpoint_name)
+        if flags_obj.enable_checkpointing:
+          # Avoid checkpointing when running for benchmarking.
+          checkpoint_name = checkpoint.save(
+              os.path.join(flags_obj.model_dir,
+                           "ctl_step_{}.ckpt".format(current_step)))
+          logging.info("Saved checkpoint to %s", checkpoint_name)
     else:
       if self.use_tpu:
         raise NotImplementedError(
@@ -397,10 +400,11 @@ class TransformerTask(object):
     scheduler_callback = optimizer.LearningRateScheduler(sfunc, init_steps)
     callbacks = misc.get_callbacks(params["steps_between_evals"])
     callbacks.append(scheduler_callback)
-    ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
-    callbacks.append(
-        tf.keras.callbacks.ModelCheckpoint(
-            ckpt_full_path, save_weights_only=True))
+    if params["enable_checkpointing"]:
+      ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
+      callbacks.append(
+          tf.keras.callbacks.ModelCheckpoint(
+              ckpt_full_path, save_weights_only=True))
     return callbacks
 
   def _load_weights_if_possible(self, model, init_weight_path=None):
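
A short sketch of the round trip for weights written by this callback. `tf.train.latest_checkpoint` and `model.load_weights` are standard TensorFlow APIs; `model` and the directory path are placeholders standing in for the task's model and `cur_log_dir`:

```python
import tensorflow as tf

# ModelCheckpoint(..., save_weights_only=True) writes one weight file per
# epoch under the "cp-{epoch:04d}.ckpt" template; pick the newest and
# restore it onto an identically built model (as _load_weights_if_possible
# above does with an explicit path).
latest = tf.train.latest_checkpoint("/tmp/transformer_model")
if latest:
  model.load_weights(latest)  # `model` is a placeholder for the task model
```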
@@ -470,7 +474,6 @@ def main(_):
 
 if __name__ == "__main__":
   tf.compat.v1.enable_v2_behavior()
   logging.set_verbosity(logging.INFO)
   misc.define_transformer_flags()
   app.run(main)
-
@@ -187,5 +187,4 @@ class TransformerTaskTest(tf.test.TestCase):
 
 if __name__ == '__main__':
   tf.compat.v1.enable_v2_behavior()
   tf.test.main()
-
@@ -65,5 +65,4 @@ class TransformerV2Test(tf.test.TestCase):
 
 if __name__ == "__main__":
   tf.compat.v1.enable_v2_behavior()
   tf.test.main()
-
@@ -17,7 +17,7 @@
 import collections
 import tempfile
 
-import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow as tf
 
 from official.nlp.transformer.utils import tokenizer
......
@@ -454,6 +454,4 @@ def main(_):
 
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   app.run(main)
-
@@ -193,5 +193,4 @@ def main(unused_argv):
 
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   app.run(main)
@@ -153,5 +153,4 @@ def main(unused_argv):
 
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   app.run(main)