"vscode:/vscode.git/clone" did not exist on "15c8830416b4278934c942012d85a39ea76c783c"
Unverified commit 1e2ceffd, authored by Ayushman Kumar and committed by GitHub

Merge pull request #4 from tensorflow/master

Updating 
parents 51e60bab c7adbbe4
@@ -27,8 +27,8 @@ from official.nlp.modeling.layers import masked_softmax
 @tf.keras.utils.register_keras_serializable(package="Text")
-class Attention(tf.keras.layers.Layer):
-  """Attention layer.
+class MultiHeadAttention(tf.keras.layers.Layer):
+  """MultiHeadAttention layer.
   This is an implementation of multi-headed attention based on "Attention
   is all you Need". If `from_tensor` and `to_tensor` are the same, then
@@ -70,7 +70,7 @@ class Attention(tf.keras.layers.Layer):
                kernel_constraint=None,
                bias_constraint=None,
                **kwargs):
-    super(Attention, self).__init__(**kwargs)
+    super(MultiHeadAttention, self).__init__(**kwargs)
     self._num_heads = num_heads
     self._head_size = head_size
     self._dropout_rate = dropout_rate
@@ -141,7 +141,7 @@ class Attention(tf.keras.layers.Layer):
         "bias_constraint":
             tf.keras.constraints.serialize(self._bias_constraint)
     }
-    base_config = super(Attention, self).get_config()
+    base_config = super(MultiHeadAttention, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
   def call(self, inputs):
@@ -183,7 +183,7 @@ class Attention(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package="Text")
-class CachedAttention(Attention):
+class CachedAttention(MultiHeadAttention):
   """Attention layer with cache used for auto-agressive decoding.
   Arguments:
...
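Usage sketch (editor's illustration, not part of the diff): the rename keeps the constructor and call signature intact, so callers only need the new class name. The arguments and shapes below mirror the test cases further down in this diff, and the import path is the one shown in the next hunk header; whether the package is installed in your environment is assumed.

import tensorflow as tf
from official.nlp.modeling.layers import attention

# Renamed class, same constructor arguments as before.
layer = attention.MultiHeadAttention(num_heads=12, head_size=64)

# Cross-attention: queries come from `from_tensor`, keys/values from `to_tensor`.
from_tensor = tf.keras.Input(shape=(40, 80))
to_tensor = tf.keras.Input(shape=(20, 80))
output = layer([from_tensor, to_tensor])

# Self-attention: pass the same tensor for both inputs.
self_output = layer([from_tensor, from_tensor])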
@@ -28,11 +28,11 @@ from official.nlp.modeling.layers import attention
 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
 # guarantees forward compatibility of this code for the V2 switchover.
 @keras_parameterized.run_all_keras_modes
-class AttentionLayerTest(keras_parameterized.TestCase):
+class MultiHeadAttentionTest(keras_parameterized.TestCase):
   def test_non_masked_attention(self):
     """Test that the attention layer can be created without a mask tensor."""
-    test_layer = attention.Attention(num_heads=12, head_size=64)
+    test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(40, 80))
     to_tensor = tf.keras.Input(shape=(20, 80))
@@ -41,7 +41,7 @@ class AttentionLayerTest(keras_parameterized.TestCase):
   def test_non_masked_self_attention(self):
     """Test with one input (self-attenntion) and no mask tensor."""
-    test_layer = attention.Attention(num_heads=12, head_size=64)
+    test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(40, 80))
     output = test_layer([from_tensor, from_tensor])
@@ -49,7 +49,7 @@ class AttentionLayerTest(keras_parameterized.TestCase):
   def test_masked_attention(self):
     """Test with a mask tensor."""
-    test_layer = attention.Attention(num_heads=2, head_size=2)
+    test_layer = attention.MultiHeadAttention(num_heads=2, head_size=2)
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(4, 8))
     to_tensor = tf.keras.Input(shape=(2, 8))
@@ -78,7 +78,7 @@ class AttentionLayerTest(keras_parameterized.TestCase):
   def test_initializer(self):
     """Test with a specified initializer."""
-    test_layer = attention.Attention(
+    test_layer = attention.MultiHeadAttention(
         num_heads=12,
         head_size=64,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
...
@@ -99,7 +99,7 @@ class Transformer(tf.keras.layers.Layer):
                        "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
-    self._attention_layer = attention.Attention(
+    self._attention_layer = attention.MultiHeadAttention(
         num_heads=self._num_heads,
         head_size=self._attention_head_size,
         dropout_rate=self._attention_dropout_rate,
...
@@ -59,7 +59,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
                num_attention_heads,
                intermediate_size,
                intermediate_activation,
-               attention_cls=attention.Attention,
+               attention_cls=attention.MultiHeadAttention,
                attention_cfg=None,
                dropout_rate=0.0,
                attention_dropout_rate=0.0,
...
@@ -33,7 +33,7 @@ from official.nlp.modeling.layers import transformer_scaffold
 # boolean 'True'. We register this class as a Keras serializable so we can
 # test serialization below.
 @tf.keras.utils.register_keras_serializable(package='TestOnly')
-class ValidatedAttentionLayer(attention.Attention):
+class ValidatedAttentionLayer(attention.MultiHeadAttention):
   def __init__(self, call_list, **kwargs):
     super(ValidatedAttentionLayer, self).__init__(**kwargs)
...
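Sketch of the subclassing pattern this test helper relies on (editor's illustration, not part of the diff). Only the constructor header is visible in the hunk above, so the attribute name and the call bookkeeping below are assumptions.

import tensorflow as tf
from official.nlp.modeling.layers import attention


@tf.keras.utils.register_keras_serializable(package='TestOnly')
class ValidatedAttentionLayer(attention.MultiHeadAttention):
  """Records each invocation so a test can validate the scaffold wiring."""

  def __init__(self, call_list, **kwargs):
    super(ValidatedAttentionLayer, self).__init__(**kwargs)
    # Assumed attribute name; the hunk only shows the constructor header.
    self.list = call_list

  def call(self, inputs):
    self.list.append(True)  # mark that the scaffold invoked this layer
    return super(ValidatedAttentionLayer, self).call(inputs)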
@@ -186,5 +186,4 @@ class TransformerLayerTest(keras_parameterized.TestCase):
 if __name__ == '__main__':
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -171,5 +171,4 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -626,5 +626,4 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -200,5 +200,4 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -26,12 +26,11 @@ import re
 import sys
 import unicodedata
-# pylint: disable=g-bad-import-order
-import six
 from absl import app as absl_app
 from absl import flags
+import six
+from six.moves import range
 import tensorflow as tf
-# pylint: enable=g-bad-import-order
 from official.nlp.transformer.utils import metrics
 from official.nlp.transformer.utils import tokenizer
...
@@ -205,6 +205,12 @@ def define_transformer_flags():
           'use padded_decode, it has not been tested. In addition, this method '
           'will introduce unnecessary overheads which grow quadratically with '
           'the max sequence length.'))
+  flags.DEFINE_bool(
+      name='enable_checkpointing',
+      default=True,
+      help=flags_core.help_wrap(
+          'Whether to do checkpointing during training. When running under '
+          'benchmark harness, we will avoid checkpointing.'))
   flags_core.set_defaults(data_dir='/tmp/translate_ende',
                           model_dir='/tmp/transformer_model',
...
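The new enable_checkpointing flag (default True) lets benchmark runs skip checkpoint writes; with absl, a boolean flag can be turned off on the command line via --noenable_checkpointing or --enable_checkpointing=false. Below is a self-contained sketch of the same flag pattern (editor's illustration; the main body is a placeholder, not code from this diff).

from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_bool(
    name='enable_checkpointing',
    default=True,
    help='Whether to do checkpointing during training.')


def main(_):
  # Placeholder for the training loop; the real task gates checkpoint.save()
  # and the ModelCheckpoint callback on this flag (see the hunks below).
  if FLAGS.enable_checkpointing:
    print('checkpointing enabled')
  else:
    print('checkpointing disabled (e.g. benchmark runs)')


if __name__ == '__main__':
  app.run(main)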
@@ -59,5 +59,4 @@ class ModelUtilsTest(tf.test.TestCase):
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -94,5 +94,4 @@ class TransformerLayersTest(tf.test.TestCase):
 if __name__ == "__main__":
-  tf.compat.v1.enable_v2_behavior()
   tf.test.main()
@@ -159,6 +159,7 @@ class TransformerTask(object):
     params["enable_tensorboard"] = flags_obj.enable_tensorboard
     params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
     params["steps_between_evals"] = flags_obj.steps_between_evals
+    params["enable_checkpointing"] = flags_obj.enable_checkpointing
     self.distribution_strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=flags_obj.distribution_strategy,
@@ -313,6 +314,8 @@ class TransformerTask(object):
             tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
                                         current_step)
+        if flags_obj.enable_checkpointing:
+          # avoid check-pointing when running for benchmarking.
           checkpoint_name = checkpoint.save(
               os.path.join(flags_obj.model_dir,
                            "ctl_step_{}.ckpt".format(current_step)))
@@ -397,6 +400,7 @@ class TransformerTask(object):
     scheduler_callback = optimizer.LearningRateScheduler(sfunc, init_steps)
     callbacks = misc.get_callbacks(params["steps_between_evals"])
     callbacks.append(scheduler_callback)
+    if params["enable_checkpointing"]:
       ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
       callbacks.append(
           tf.keras.callbacks.ModelCheckpoint(
@@ -470,7 +474,6 @@ def main(_):
 if __name__ == "__main__":
-  tf.compat.v1.enable_v2_behavior()
   logging.set_verbosity(logging.INFO)
   misc.define_transformer_flags()
   app.run(main)
@@ -187,5 +187,4 @@ class TransformerTaskTest(tf.test.TestCase):
 if __name__ == '__main__':
-  tf.compat.v1.enable_v2_behavior()
   tf.test.main()
@@ -65,5 +65,4 @@ class TransformerV2Test(tf.test.TestCase):
 if __name__ == "__main__":
-  tf.compat.v1.enable_v2_behavior()
   tf.test.main()
@@ -17,7 +17,7 @@
 import collections
 import tempfile
-import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow as tf
 from official.nlp.transformer.utils import tokenizer
...
@@ -454,6 +454,4 @@ def main(_):
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   app.run(main)
@@ -193,5 +193,4 @@ def main(unused_argv):
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   app.run(main)
@@ -153,5 +153,4 @@ def main(unused_argv):
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   app.run(main)