Commit 662a40be authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Clean up internal Seq2SeqTransformer usage.

Add it to export. TODO: move the encoder and decoder layer out to layers/ or decoders/.

PiperOrigin-RevId: 335560860
parent e25fe5dd
......@@ -19,4 +19,5 @@ from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
from official.nlp.modeling.models.dual_encoder import DualEncoder
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
from official.nlp.modeling.models.seq2seq_transformer import *
from official.nlp.modeling.models.xlnet import XLNetClassifier
......@@ -23,73 +23,12 @@ from official.modeling import tf_utils
from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.modeling.ops import beam_search
from official.nlp.transformer import metrics
from official.nlp.transformer import model_utils
EOS_ID = 1
# pylint: disable=g-classes-have-attributes
def create_model(params, is_train):
  """Builds a Seq2SeqTransformer Keras model from a hyperparameter dict.

  Args:
    params: Dict of hyperparameters (sizes, dropout rates, decoding options).
    is_train: If True, returns a training model that takes `[inputs, targets]`
      and emits logits with the transformer loss attached; otherwise returns an
      inference model that takes `[inputs]` and emits decoded outputs/scores.

  Returns:
    A `tf.keras.Model` wired for either training or autoregressive decoding.
  """
  # Encoder and decoder share the same stack configuration.
  layer_kwargs = {
      "num_layers": params["num_hidden_layers"],
      "num_attention_heads": params["num_heads"],
      "intermediate_size": params["filter_size"],
      "activation": "relu",
      "dropout_rate": params["relu_dropout"],
      "attention_dropout_rate": params["attention_dropout"],
      "use_bias": False,
      "norm_first": True,
      "norm_epsilon": 1e-6,
      "intermediate_dropout": params["relu_dropout"],
  }
  transformer_kwargs = {
      "vocab_size": params["vocab_size"],
      "embedding_width": params["hidden_size"],
      "dropout_rate": params["layer_postprocess_dropout"],
      "padded_decode": params["padded_decode"],
      "decode_max_length": params["decode_max_length"],
      "dtype": params["dtype"],
      "extra_decode_length": params["extra_decode_length"],
      "beam_size": params["beam_size"],
      "alpha": params["alpha"],
      "encoder_layer": TransformerEncoder(**layer_kwargs),
      "decoder_layer": TransformerDecoder(**layer_kwargs),
      "name": "transformer_v2",
  }

  if is_train:
    inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
    targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
    transformer = Seq2SeqTransformer(**transformer_kwargs)
    logits = transformer([inputs, targets], training=is_train)
    vocab_size = params["vocab_size"]
    label_smoothing = params["label_smoothing"]
    if params["enable_metrics_in_training"]:
      logits = metrics.MetricLayer(vocab_size)([logits, targets])
    # Identity Lambda pins the output name and forces float32 logits.
    logits = tf.keras.layers.Lambda(
        lambda x: x, name="logits", dtype=tf.float32)(
            logits)
    model = tf.keras.Model([inputs, targets], logits)
    model.add_loss(
        metrics.transformer_loss(logits, targets, label_smoothing, vocab_size))
    return model

  # Inference: padded decoding needs a static batch dimension.
  batch_size = params["decode_batch_size"] if params["padded_decode"] else None
  inputs = tf.keras.layers.Input((None,),
                                 batch_size=batch_size,
                                 dtype="int64",
                                 name="inputs")
  transformer = Seq2SeqTransformer(**transformer_kwargs)
  decoded = transformer([inputs], training=is_train)
  return tf.keras.Model(inputs, [decoded["outputs"], decoded["scores"]])
@tf.keras.utils.register_keras_serializable(package="Text")
class Seq2SeqTransformer(tf.keras.Model):
"""Transformer model with Keras.
......
......@@ -22,44 +22,10 @@ import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.modeling.models import seq2seq_transformer
from official.nlp.transformer import model_params
class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
def test_create_model(self):
  """Checks input/output signatures of the training and inference models."""
  self.params = model_params.TINY_PARAMS
  # Shrink the config so the test builds quickly.
  self.params.update(
      batch_size=16,
      hidden_size=12,
      num_hidden_layers=2,
      filter_size=14,
      num_heads=2,
      vocab_size=41,
      extra_decode_length=2,
      beam_size=3,
      dtype=tf.float32)

  # Training model: [inputs, targets] -> logits over the vocabulary.
  train_model = seq2seq_transformer.create_model(self.params, is_train=True)
  self.assertLen(train_model.inputs, 2)
  self.assertLen(train_model.outputs, 1)
  for tensor in train_model.inputs:
    self.assertEqual(tensor.shape.as_list(), [None, None])
    self.assertEqual(tensor.dtype, tf.int64)
  logits = train_model.outputs[0]
  self.assertEqual(logits.shape.as_list(), [None, None, 41])
  self.assertEqual(logits.dtype, tf.float32)

  # Inference model: [inputs] -> (decoded ids, beam scores).
  eval_model = seq2seq_transformer.create_model(self.params, is_train=False)
  self.assertLen(eval_model.inputs, 1)
  self.assertLen(eval_model.outputs, 2)
  self.assertEqual(eval_model.inputs[0].shape.as_list(), [None, None])
  self.assertEqual(eval_model.inputs[0].dtype, tf.int64)
  outputs, scores = eval_model.outputs
  self.assertEqual(outputs.shape.as_list(), [None, None])
  self.assertEqual(outputs.dtype, tf.int32)
  self.assertEqual(scores.shape.as_list(), [None])
  self.assertEqual(scores.dtype, tf.float32)
def _build_model(self, padded_decode, decode_max_length):
num_layers = 1
num_attention_heads = 2
......
......@@ -18,7 +18,8 @@ import numpy as np
import tensorflow as tf
from official.nlp.modeling.models import seq2seq_transformer
from official.nlp.modeling import models
from official.nlp.transformer import metrics
from official.nlp.transformer import model_params
from official.nlp.transformer import transformer
......@@ -34,6 +35,66 @@ def _count_params(layer, trainable_only=True):
]))
def _create_model(params, is_train):
  """Test helper: builds a Seq2SeqTransformer model from a params dict.

  Mirrors the production `create_model` factory, but goes through the public
  `models` package so the test exercises the exported API.

  Args:
    params: Dict of hyperparameters (sizes, dropout rates, decoding options).
    is_train: Whether to build the training graph (logits + loss) or the
      inference graph (decoded outputs + scores).

  Returns:
    A `tf.keras.Model`.
  """
  stack_config = {
      "num_layers": params["num_hidden_layers"],
      "num_attention_heads": params["num_heads"],
      "intermediate_size": params["filter_size"],
      "activation": "relu",
      "dropout_rate": params["relu_dropout"],
      "attention_dropout_rate": params["attention_dropout"],
      "use_bias": False,
      "norm_first": True,
      "norm_epsilon": 1e-6,
      "intermediate_dropout": params["relu_dropout"],
  }
  seq2seq_config = {
      "vocab_size": params["vocab_size"],
      "embedding_width": params["hidden_size"],
      "dropout_rate": params["layer_postprocess_dropout"],
      "padded_decode": params["padded_decode"],
      "decode_max_length": params["decode_max_length"],
      "dtype": params["dtype"],
      "extra_decode_length": params["extra_decode_length"],
      "beam_size": params["beam_size"],
      "alpha": params["alpha"],
      "encoder_layer": models.TransformerEncoder(**stack_config),
      "decoder_layer": models.TransformerDecoder(**stack_config),
      "name": "transformer_v2",
  }

  if not is_train:
    # Padded decoding requires a static batch dimension on the input.
    batch_size = (
        params["decode_batch_size"] if params["padded_decode"] else None)
    inputs = tf.keras.layers.Input((None,),
                                   batch_size=batch_size,
                                   dtype="int64",
                                   name="inputs")
    seq2seq = models.Seq2SeqTransformer(**seq2seq_config)
    decoded = seq2seq([inputs], training=is_train)
    return tf.keras.Model(inputs, [decoded["outputs"], decoded["scores"]])

  inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
  targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
  seq2seq = models.Seq2SeqTransformer(**seq2seq_config)
  logits = seq2seq([inputs, targets], training=is_train)
  vocab_size = params["vocab_size"]
  label_smoothing = params["label_smoothing"]
  if params["enable_metrics_in_training"]:
    logits = metrics.MetricLayer(vocab_size)([logits, targets])
  # Identity Lambda fixes the output name and casts logits to float32.
  logits = tf.keras.layers.Lambda(
      lambda x: x, name="logits", dtype=tf.float32)(
          logits)
  model = tf.keras.Model([inputs, targets], logits)
  model.add_loss(
      metrics.transformer_loss(logits, targets, label_smoothing, vocab_size))
  return model
class TransformerForwardTest(tf.test.TestCase):
def setUp(self):
......@@ -65,7 +126,7 @@ class TransformerForwardTest(tf.test.TestCase):
src_model_output = src_model([inputs, targets], training=True)
# dest_model is the refactored model.
dest_model = seq2seq_transformer.create_model(self.params, True)
dest_model = _create_model(self.params, True)
dest_num_weights = _count_params(dest_model)
self.assertEqual(src_num_weights, dest_num_weights)
dest_model.set_weights(src_weights)
......@@ -82,7 +143,7 @@ class TransformerForwardTest(tf.test.TestCase):
src_model_output = src_model([inputs], training=False)
# dest_model is the refactored model.
dest_model = seq2seq_transformer.create_model(self.params, False)
dest_model = _create_model(self.params, False)
dest_num_weights = _count_params(dest_model)
self.assertEqual(src_num_weights, dest_num_weights)
dest_model.set_weights(src_weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.