Commit f0e2f833 authored by xinliupitt's avatar xinliupitt

remove params

parent 9bb04e60
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the Transformer model in TF 2.0.
"""Implement Seq2Seq Transformer model by TF official NLP library.
Model paper: https://arxiv.org/pdf/1706.03762.pdf
Transformer model code source: https://github.com/tensorflow/tensor2tensor
TF official NLP library:
https://github.com/tensorflow/models/tree/master/official/nlp/modeling
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
@@ -29,6 +26,7 @@ from official.nlp.modeling import layers
from official.nlp.modeling.layers import position_embedding
from official.nlp.modeling.layers import transformer
from official.nlp.modeling.ops import beam_search
+from official.nlp.transformer import metrics
from official.nlp.transformer import model_utils
from official.nlp.transformer.utils.tokenizer import EOS_ID
@@ -37,6 +35,66 @@ from official.nlp.transformer.utils.tokenizer import EOS_ID
# callable when they actually are.
# pylint: disable=not-callable
+def create_model(params, is_train):
+  """Creates transformer model."""
+  encdec_kwargs = dict(
+      num_layers=params["num_hidden_layers"],
+      num_attention_heads=params["num_heads"],
+      intermediate_size=params["filter_size"],
+      activation="relu",
+      dropout_rate=params["relu_dropout"],
+      attention_dropout_rate=params["attention_dropout"],
+      use_bias=False,
+      norm_first=True,
+      norm_epsilon=1e-6,
+      intermediate_dropout=params["relu_dropout"])
+  encoder_layer = TransformerEncoder(**encdec_kwargs)
+  decoder_layer = TransformerDecoder(**encdec_kwargs)
+
+  model_kwargs = dict(
+      vocab_size=params["vocab_size"],
+      hidden_size=params["hidden_size"],
+      dropout_rate=params["layer_postprocess_dropout"],
+      padded_decode=params["padded_decode"],
+      num_replicas=params["num_replicas"],
+      decode_batch_size=params["decode_batch_size"],
+      decode_max_length=params["decode_max_length"],
+      dtype=params["dtype"],
+      extra_decode_length=params["extra_decode_length"],
+      num_heads=params["num_heads"],
+      num_layers=params["num_hidden_layers"],
+      beam_size=params["beam_size"],
+      alpha=params["alpha"],
+      encoder_layer=encoder_layer,
+      decoder_layer=decoder_layer,
+      name="transformer_v2")
+
+  with tf.name_scope("model"):
+    if is_train:
+      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
+      targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
+      internal_model = Seq2SeqTransformer(**model_kwargs)
+      logits = internal_model([inputs, targets], training=is_train)
+      vocab_size = params["vocab_size"]
+      label_smoothing = params["label_smoothing"]
+      if params["enable_metrics_in_training"]:
+        logits = metrics.MetricLayer(vocab_size)([logits, targets])
+      logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
+                                      dtype=tf.float32)(logits)
+      model = tf.keras.Model([inputs, targets], logits)
+      loss = metrics.transformer_loss(
+          logits, targets, label_smoothing, vocab_size)
+      model.add_loss(loss)
+      return model
+    else:
+      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
+      internal_model = Seq2SeqTransformer(**model_kwargs)
+      ret = internal_model([inputs], training=is_train)
+      outputs, scores = ret["outputs"], ret["scores"]
+      return tf.keras.Model(inputs, [outputs, scores])
@tf.keras.utils.register_keras_serializable(package="Text")
class Seq2SeqTransformer(tf.keras.Model):
"""Transformer model with Keras.
@@ -49,54 +107,108 @@ class Seq2SeqTransformer(tf.keras.Model):
probabilities for the output sequence.
"""
-  def __init__(self, params, name=None):
+  def __init__(self,
+               vocab_size=33708,
+               hidden_size=512,
+               dropout_rate=0.0,
+               padded_decode=False,
+               num_replicas=1,
+               decode_batch_size=2048,
+               decode_max_length=97,
+               dtype=tf.float32,
+               extra_decode_length=0,
+               num_heads=8,
+               num_layers=6,
+               beam_size=4,
+               alpha=0.6,
+               encoder_layer=None,
+               decoder_layer=None,
+               name=None,
+               **kwargs):
"""Initialize layers to build Transformer model.
-    Args:
-      params: hyperparameter object defining layer sizes, dropout values, etc.
+    Arguments:
+      vocab_size: Size of vocabulary.
+      hidden_size: Size of hidden layer for embedding.
+      dropout_rate: Dropout probability.
+      padded_decode: Whether max_sequence_length padding is used. If set to
+        False, max_sequence_length padding is not used.
+      num_replicas: Number of replicas for the distribution strategy.
+      decode_batch_size: Batch size for decoding.
+      decode_max_length: Maximum number of steps to decode a sequence.
+      dtype: Data type.
+      extra_decode_length: Number of decode steps allowed beyond the input
+        length.
+      num_heads: Number of attention heads.
+      num_layers: Number of identical layers for the Transformer architecture.
+      beam_size: Number of beams for beam search.
+      alpha: The strength of length normalization for beam search.
+      encoder_layer: An initialized encoder layer.
+      decoder_layer: An initialized decoder layer.
name: name of the model.
"""
-    super(Seq2SeqTransformer, self).__init__(name=name)
-    self.params = params
+    super(Seq2SeqTransformer, self).__init__(name=name, **kwargs)
+    self._vocab_size = vocab_size
+    self._hidden_size = hidden_size
+    self._dropout_rate = dropout_rate
+    self._padded_decode = padded_decode
+    self._num_replicas = num_replicas
+    self._decode_batch_size = decode_batch_size
+    self._decode_max_length = decode_max_length
+    self._dtype = dtype
+    self._extra_decode_length = extra_decode_length
+    self._num_heads = num_heads
+    self._num_layers = num_layers
+    self._beam_size = beam_size
+    self._alpha = alpha
    self.embedding_lookup = layers.OnDeviceEmbedding(
-        vocab_size=params["vocab_size"],
-        embedding_width=params["hidden_size"],
+        vocab_size=self._vocab_size,
+        embedding_width=self._hidden_size,
        initializer=tf.random_normal_initializer(
-            mean=0., stddev=params["hidden_size"]**-0.5),
+            mean=0., stddev=self._hidden_size**-0.5),
        use_scale=True)
-    self.encoder_layer = TransformerEncoder(
-        num_layers=self.params["num_hidden_layers"],
-        num_attention_heads=self.params["num_heads"],
-        intermediate_size=self.params["filter_size"],
-        activation="relu",
-        dropout_rate=self.params["relu_dropout"],
-        attention_dropout_rate=self.params["attention_dropout"],
-        use_bias=False,
-        norm_first=True,
-        norm_epsilon=1e-6,
-        intermediate_dropout=self.params["relu_dropout"])
-    self.decoder_layer = TransformerDecoder(
-        num_layers=self.params["num_hidden_layers"],
-        num_attention_heads=self.params["num_heads"],
-        intermediate_size=self.params["filter_size"],
-        activation="relu",
-        dropout_rate=self.params["relu_dropout"],
-        attention_dropout_rate=self.params["attention_dropout"],
-        use_bias=False,
-        norm_first=True,
-        norm_epsilon=1e-6,
-        intermediate_dropout=self.params["relu_dropout"])
+    self.encoder_layer = encoder_layer
+    self.decoder_layer = decoder_layer
    self.position_embedding = position_embedding.RelativePositionEmbedding(
-        hidden_size=self.params["hidden_size"])
+        hidden_size=self._hidden_size)
    self.encoder_dropout = tf.keras.layers.Dropout(
-        rate=self.params["layer_postprocess_dropout"])
+        rate=self._dropout_rate)
    self.decoder_dropout = tf.keras.layers.Dropout(
-        rate=self.params["layer_postprocess_dropout"])
+        rate=self._dropout_rate)
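With the params dict removed, callers build the encoder and decoder first and pass them in. A minimal direct-construction sketch (sizes are illustrative; hidden_size must be divisible by num_heads, since dim_per_head = hidden_size // num_heads below):

    encdec_kwargs = dict(num_layers=2, num_attention_heads=2,
                         intermediate_size=14)
    model = Seq2SeqTransformer(
        vocab_size=41, hidden_size=12, num_heads=2, num_layers=2,
        encoder_layer=TransformerEncoder(**encdec_kwargs),
        decoder_layer=TransformerDecoder(**encdec_kwargs))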
  def get_config(self):
-    return {
-        "params": self.params,
-    }
+    config = {
+        "vocab_size": self._vocab_size,
+        "hidden_size": self._hidden_size,
+        "dropout_rate": self._dropout_rate,
+        "padded_decode": self._padded_decode,
+        "num_replicas": self._num_replicas,
+        "decode_batch_size": self._decode_batch_size,
+        "decode_max_length": self._decode_max_length,
+        "dtype": self._dtype,
+        "extra_decode_length": self._extra_decode_length,
+        "num_heads": self._num_heads,
+        "num_layers": self._num_layers,
+        "beam_size": self._beam_size,
+        "alpha": self._alpha,
+        "encoder_layer": self.encoder_layer,
+        "decoder_layer": self.decoder_layer,
+    }
+    base_config = super(Seq2SeqTransformer, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
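Because the class is registered as Keras-serializable and get_config now merges the base config, an in-process round trip becomes possible. A sketch, assuming the installed TF version supports get_config on subclassed models; note the config carries the live encoder/decoder layer objects, so this is not a full on-disk serialization:

    config = model.get_config()
    clone = Seq2SeqTransformer.from_config(config)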
def call(self, inputs):
"""Calculate target logits or inferred target sequences.
@@ -124,19 +236,19 @@ class Seq2SeqTransformer(tf.keras.Model):
else:
# Decoding path.
inputs, targets = inputs[0], None
if self.params["padded_decode"]:
if not self.params["num_replicas"]:
if self._padded_decode:
if not self._num_replicas:
raise NotImplementedError(
"Padded decoding on CPU/GPUs is not supported.")
decode_batch_size = int(self.params["decode_batch_size"] /
self.params["num_replicas"])
decode_batch_size = int(self._decode_batch_size /
self._num_replicas)
inputs.set_shape([
decode_batch_size, self.params["decode_max_length"]
decode_batch_size, self._decode_max_length
])
with tf.name_scope("Transformer"):
attention_bias = model_utils.get_padding_bias(inputs)
-      attention_bias = tf.cast(attention_bias, self.params["dtype"])
+      attention_bias = tf.cast(attention_bias, self._dtype)
with tf.name_scope("encode"):
# Prepare inputs to the layer stack by adding positional encodings and
# applying dropout.
@@ -144,7 +256,7 @@ class Seq2SeqTransformer(tf.keras.Model):
embedding_mask = tf.cast(tf.not_equal(inputs, 0),
self.embedding_lookup.embeddings.dtype)
embedded_inputs *= tf.expand_dims(embedding_mask, -1)
-        embedded_inputs = tf.cast(embedded_inputs, self.params["dtype"])
+        embedded_inputs = tf.cast(embedded_inputs, self._dtype)
# Attention_mask generation.
input_shape = tf_utils.get_shape_list(inputs, expected_rank=2)
@@ -158,7 +270,7 @@
with tf.name_scope("add_pos_encoding"):
pos_encoding = self.position_embedding(inputs=embedded_inputs)
-          pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
+          pos_encoding = tf.cast(pos_encoding, self._dtype)
encoder_inputs = embedded_inputs + pos_encoding
encoder_inputs = self.encoder_dropout(encoder_inputs)
@@ -168,16 +280,16 @@
if targets is None:
encoder_decoder_attention_bias = attention_bias
-        encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
-        if self.params["padded_decode"]:
+        encoder_outputs = tf.cast(encoder_outputs, self._dtype)
+        if self._padded_decode:
batch_size = encoder_outputs.shape.as_list()[0]
input_length = encoder_outputs.shape.as_list()[1]
else:
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
-        max_decode_length = input_length + self.params["extra_decode_length"]
+        max_decode_length = input_length + self._extra_decode_length
        encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
-                                                 self.params["dtype"])
+                                                 self._dtype)
symbols_to_logits_fn = self._get_symbols_to_logits_fn(
max_decode_length)
@@ -188,9 +300,9 @@
# Create cache storing decoder attention values for each layer.
# pylint: disable=g-complex-comprehension
init_decode_length = (
-            max_decode_length if self.params["padded_decode"] else 0)
-        num_heads = self.params["num_heads"]
-        dim_per_head = self.params["hidden_size"] // num_heads
+            max_decode_length if self._padded_decode else 0)
+        num_heads = self._num_heads
+        dim_per_head = self._hidden_size // num_heads
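        # For reference: with the constructor defaults (hidden_size=512,
        # num_heads=8), each head gets dim_per_head = 512 // 8 = 64.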
cache = {
str(layer): {
@@ -198,13 +310,13 @@
tf.zeros([
batch_size, init_decode_length, num_heads, dim_per_head
],
-                         dtype=self.params["dtype"]),
+                         dtype=self._dtype),
            "value":
                tf.zeros([
                    batch_size, init_decode_length, num_heads, dim_per_head
                ],
-                         dtype=self.params["dtype"])
-        } for layer in range(self.params["num_hidden_layers"])
+                         dtype=self._dtype)
+        } for layer in range(self._num_layers)
}
# pylint: enable=g-complex-comprehension
@@ -218,13 +330,13 @@
symbols_to_logits_fn=symbols_to_logits_fn,
initial_ids=initial_ids,
initial_cache=cache,
-            vocab_size=self.params["vocab_size"],
-            beam_size=self.params["beam_size"],
-            alpha=self.params["alpha"],
+            vocab_size=self._vocab_size,
+            beam_size=self._beam_size,
+            alpha=self._alpha,
max_decode_length=max_decode_length,
eos_id=EOS_ID,
-            padded_decode=self.params["padded_decode"],
-            dtype=self.params["dtype"])
+            padded_decode=self._padded_decode,
+            dtype=self._dtype)
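        # decoded_ids stacks all beams: [batch_size, beam_size, decoded_length];
        # the slice below keeps only the top-scoring beam for each example.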
# Get the top sequence for each batch element
top_decoded_ids = decoded_ids[:, 0, 1:]
@@ -238,7 +350,7 @@
embedding_mask = tf.cast(tf.not_equal(targets, 0),
self.embedding_lookup.embeddings.dtype)
decoder_inputs *= tf.expand_dims(embedding_mask, -1)
-        decoder_inputs = tf.cast(decoder_inputs, self.params["dtype"])
+        decoder_inputs = tf.cast(decoder_inputs, self._dtype)
with tf.name_scope("shift_targets"):
# Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(decoder_inputs,
@@ -246,7 +358,7 @@
with tf.name_scope("add_pos_encoding"):
length = tf.shape(decoder_inputs)[1]
pos_encoding = self.position_embedding(decoder_inputs)
-          pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
+          pos_encoding = tf.cast(pos_encoding, self._dtype)
decoder_inputs += pos_encoding
decoder_inputs = self.decoder_dropout(decoder_inputs)
@@ -282,9 +394,9 @@
"""Returns a decoding function that calculates logits of the next tokens."""
timing_signal = self.position_embedding(
inputs=None, length=max_decode_length + 1)
-    timing_signal = tf.cast(timing_signal, self.params["dtype"])
+    timing_signal = tf.cast(timing_signal, self._dtype)
    decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
-        max_decode_length, dtype=self.params["dtype"])
+        max_decode_length, dtype=self._dtype)
def symbols_to_logits_fn(ids, i, cache):
"""Generate logits for next potential IDs.
@@ -312,7 +424,7 @@
self.embedding_lookup.embeddings.dtype)
decoder_input *= tf.expand_dims(embedding_mask, -1)
-      if self.params["padded_decode"]:
+      if self._padded_decode:
timing_signal_shape = timing_signal.shape.as_list()
decoder_input += tf.slice(timing_signal, [i, 0],
[1, timing_signal_shape[1]])
@@ -350,7 +462,7 @@
memory_mask=self_attention_mask,
target_mask=attention_mask,
cache=cache,
-          decode_loop_step=i if self.params["padded_decode"] else None)
+          decode_loop_step=i if self._padded_decode else None)
logits = embedding_linear(self.embedding_lookup.embeddings,
decoder_outputs)
@@ -392,8 +504,9 @@ class TransformerEncoder(tf.keras.layers.Layer):
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
-               intermediate_dropout=0.0):
-    super(TransformerEncoder, self).__init__()
+               intermediate_dropout=0.0,
+               **kwargs):
+    super(TransformerEncoder, self).__init__(**kwargs)
self._num_layers = num_layers
self._num_attention_heads = num_attention_heads
self._intermediate_size = intermediate_size
@@ -507,8 +620,9 @@ class TransformerDecoder(tf.keras.layers.Layer):
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
-               intermediate_dropout=0.0):
-    super(TransformerDecoder, self).__init__()
+               intermediate_dropout=0.0,
+               **kwargs):
+    super(TransformerDecoder, self).__init__(**kwargs)
self._num_layers = num_layers
self._num_attention_heads = num_attention_heads
self._intermediate_size = intermediate_size
@@ -28,7 +28,6 @@ class TransformerV2Test(tf.test.TestCase):
def setUp(self):
self.params = params = model_params.TINY_PARAMS
params["batch_size"] = params["default_batch_size"] = 16
params["use_synthetic_data"] = True
params["hidden_size"] = 12
params["num_hidden_layers"] = 2
params["filter_size"] = 14
@@ -39,11 +38,7 @@ class TransformerV2Test(tf.test.TestCase):
params["dtype"] = tf.float32
def test_create_model_train(self):
-    inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
-    targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
-    internal_model = seq2seq_transformer.Seq2SeqTransformer(self.params)
-    logits = internal_model([inputs, targets], training=True)
-    model = tf.keras.Model([inputs, targets], logits)
+    model = seq2seq_transformer.create_model(self.params, True)
inputs, outputs = model.inputs, model.outputs
self.assertEqual(len(inputs), 2)
self.assertEqual(len(outputs), 1)
@@ -55,11 +50,7 @@ class TransformerV2Test(tf.test.TestCase):
self.assertEqual(outputs[0].dtype, tf.float32)
def test_create_model_not_train(self):
-    inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
-    internal_model = seq2seq_transformer.Seq2SeqTransformer(self.params)
-    ret = internal_model([inputs], training=False)
-    outputs, scores = ret["outputs"], ret["scores"]
-    model = tf.keras.Model(inputs, [outputs, scores])
+    model = seq2seq_transformer.create_model(self.params, False)
inputs, outputs = model.inputs, model.outputs
self.assertEqual(len(inputs), 1)
self.assertEqual(len(outputs), 2)
@@ -71,5 +62,6 @@ class TransformerV2Test(tf.test.TestCase):
self.assertEqual(outputs[1].dtype, tf.float32)
if __name__ == "__main__":
tf.test.main()