"vscode:/vscode.git/clone" did not exist on "3bc3ed636ba20063f7baf64bf7f993cde97146f2"
Commit 00e598f0 authored by Hongkun Yu, committed by A. Unique TensorFlower

Clean up internal Seq2SeqTransformer usage.

Add it to the public export. TODO: move the encoder and decoder layers out to layers/ or decoders/.

PiperOrigin-RevId: 335560860
parent 1407e274
@@ -19,4 +19,5 @@ from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
from official.nlp.modeling.models.dual_encoder import DualEncoder
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
from official.nlp.modeling.models.seq2seq_transformer import *
from official.nlp.modeling.models.xlnet import XLNetClassifier
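With the wildcard export in place, the Seq2SeqTransformer model (and, until the TODO above is addressed, the TransformerEncoder/TransformerDecoder layers defined alongside it) can be reached directly from the models package. A minimal construction sketch follows; the values are arbitrary, and any constructor argument not shown is assumed to have a default:

# Hypothetical usage sketch, not part of this change. Only names that the
# diffs below show being exported are used; omitted arguments are assumed
# to have defaults.
from official.nlp.modeling import models

encdec_kwargs = dict(
    num_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    activation="relu",
    dropout_rate=0.1,
    attention_dropout_rate=0.1)
model = models.Seq2SeqTransformer(
    vocab_size=1000,
    embedding_width=64,
    dropout_rate=0.1,
    beam_size=4,
    alpha=0.6,
    encoder_layer=models.TransformerEncoder(**encdec_kwargs),
    decoder_layer=models.TransformerDecoder(**encdec_kwargs))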
@@ -23,73 +23,12 @@ from official.modeling import tf_utils
from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.modeling.ops import beam_search
from official.nlp.transformer import metrics
from official.nlp.transformer import model_utils
EOS_ID = 1
# pylint: disable=g-classes-have-attributes
def create_model(params, is_train):
  """Creates transformer model."""
  # Shared hyperparameters for the encoder and decoder stacks.
  encdec_kwargs = dict(
      num_layers=params["num_hidden_layers"],
      num_attention_heads=params["num_heads"],
      intermediate_size=params["filter_size"],
      activation="relu",
      dropout_rate=params["relu_dropout"],
      attention_dropout_rate=params["attention_dropout"],
      use_bias=False,
      norm_first=True,
      norm_epsilon=1e-6,
      intermediate_dropout=params["relu_dropout"])
  encoder_layer = TransformerEncoder(**encdec_kwargs)
  decoder_layer = TransformerDecoder(**encdec_kwargs)

  model_kwargs = dict(
      vocab_size=params["vocab_size"],
      embedding_width=params["hidden_size"],
      dropout_rate=params["layer_postprocess_dropout"],
      padded_decode=params["padded_decode"],
      decode_max_length=params["decode_max_length"],
      dtype=params["dtype"],
      extra_decode_length=params["extra_decode_length"],
      beam_size=params["beam_size"],
      alpha=params["alpha"],
      encoder_layer=encoder_layer,
      decoder_layer=decoder_layer,
      name="transformer_v2")

  if is_train:
    # Training graph: consumes source and target ids, emits float32 logits,
    # and attaches the label-smoothed loss via add_loss.
    inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
    targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
    internal_model = Seq2SeqTransformer(**model_kwargs)
    logits = internal_model([inputs, targets], training=is_train)
    vocab_size = params["vocab_size"]
    label_smoothing = params["label_smoothing"]
    if params["enable_metrics_in_training"]:
      logits = metrics.MetricLayer(vocab_size)([logits, targets])
    logits = tf.keras.layers.Lambda(
        lambda x: x, name="logits", dtype=tf.float32)(
            logits)
    model = tf.keras.Model([inputs, targets], logits)
    loss = metrics.transformer_loss(logits, targets, label_smoothing,
                                    vocab_size)
    model.add_loss(loss)
    return model

  # Inference graph: beam-search decode returning token ids and scores.
  batch_size = params["decode_batch_size"] if params["padded_decode"] else None
  inputs = tf.keras.layers.Input((None,),
                                 batch_size=batch_size,
                                 dtype="int64",
                                 name="inputs")
  internal_model = Seq2SeqTransformer(**model_kwargs)
  ret = internal_model([inputs], training=is_train)
  outputs, scores = ret["outputs"], ret["scores"]
  return tf.keras.Model(inputs, [outputs, scores])
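For reference, create_model consults only the keys listed below in params; an illustrative dict (values are arbitrary and chosen purely for this sketch) looks like:

# Illustrative params for create_model; every key the function reads is
# listed, with small arbitrary values.
params = {
    "num_hidden_layers": 2,
    "num_heads": 2,
    "filter_size": 32,
    "relu_dropout": 0.1,
    "attention_dropout": 0.1,
    "vocab_size": 64,
    "hidden_size": 16,
    "layer_postprocess_dropout": 0.1,
    "padded_decode": False,       # if True, "decode_batch_size" is also read
    "decode_max_length": 10,
    "dtype": tf.float32,
    "extra_decode_length": 2,
    "beam_size": 3,
    "alpha": 0.6,
    "label_smoothing": 0.1,               # training path only
    "enable_metrics_in_training": False,  # training path only
}
train_model = create_model(params, is_train=True)   # [inputs, targets] -> logits
infer_model = create_model(params, is_train=False)  # [inputs] -> (outputs, scores)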
@tf.keras.utils.register_keras_serializable(package="Text")
class Seq2SeqTransformer(tf.keras.Model):
"""Transformer model with Keras.
......
@@ -22,44 +22,10 @@ import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.modeling.models import seq2seq_transformer
from official.nlp.transformer import model_params
class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):

  def test_create_model(self):
    self.params = model_params.TINY_PARAMS
    self.params["batch_size"] = 16
    self.params["hidden_size"] = 12
    self.params["num_hidden_layers"] = 2
    self.params["filter_size"] = 14
    self.params["num_heads"] = 2
    self.params["vocab_size"] = 41
    self.params["extra_decode_length"] = 2
    self.params["beam_size"] = 3
    self.params["dtype"] = tf.float32

    model = seq2seq_transformer.create_model(self.params, is_train=True)
    inputs, outputs = model.inputs, model.outputs
    self.assertLen(inputs, 2)
    self.assertLen(outputs, 1)
    self.assertEqual(inputs[0].shape.as_list(), [None, None])
    self.assertEqual(inputs[0].dtype, tf.int64)
    self.assertEqual(inputs[1].shape.as_list(), [None, None])
    self.assertEqual(inputs[1].dtype, tf.int64)
    self.assertEqual(outputs[0].shape.as_list(), [None, None, 41])
    self.assertEqual(outputs[0].dtype, tf.float32)

    model = seq2seq_transformer.create_model(self.params, is_train=False)
    inputs, outputs = model.inputs, model.outputs
    self.assertLen(inputs, 1)
    self.assertLen(outputs, 2)
    self.assertEqual(inputs[0].shape.as_list(), [None, None])
    self.assertEqual(inputs[0].dtype, tf.int64)
    self.assertEqual(outputs[0].shape.as_list(), [None, None])
    self.assertEqual(outputs[0].dtype, tf.int32)
    self.assertEqual(outputs[1].shape.as_list(), [None])
    self.assertEqual(outputs[1].dtype, tf.float32)

  def _build_model(self, padded_decode, decode_max_length):
    num_layers = 1
    num_attention_heads = 2
......
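A short sketch of exercising both graphs built by create_model with dummy data, matching the dtypes and shapes the test above asserts (the arrays here are illustrative and not part of the test):

# Continuing inside test_create_model above; numpy is assumed imported as np
# and the arrays are made up purely for this sketch.
train_model = seq2seq_transformer.create_model(self.params, is_train=True)
src = np.random.randint(0, 41, size=(16, 8), dtype=np.int64)
tgt = np.random.randint(0, 41, size=(16, 6), dtype=np.int64)
logits = train_model([src, tgt], training=True)       # float32, shape (16, 6, 41)

infer_model = seq2seq_transformer.create_model(self.params, is_train=False)
outputs, scores = infer_model([src], training=False)  # int32 ids, float32 scores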
@@ -18,7 +18,8 @@ import numpy as np
import tensorflow as tf
from official.nlp.modeling.models import seq2seq_transformer
from official.nlp.modeling import models
from official.nlp.transformer import metrics
from official.nlp.transformer import model_params
from official.nlp.transformer import transformer
@@ -34,6 +35,66 @@ def _count_params(layer, trainable_only=True):
]))
def _create_model(params, is_train):
  """Creates transformer model."""
  encdec_kwargs = dict(
      num_layers=params["num_hidden_layers"],
      num_attention_heads=params["num_heads"],
      intermediate_size=params["filter_size"],
      activation="relu",
      dropout_rate=params["relu_dropout"],
      attention_dropout_rate=params["attention_dropout"],
      use_bias=False,
      norm_first=True,
      norm_epsilon=1e-6,
      intermediate_dropout=params["relu_dropout"])
  encoder_layer = models.TransformerEncoder(**encdec_kwargs)
  decoder_layer = models.TransformerDecoder(**encdec_kwargs)

  model_kwargs = dict(
      vocab_size=params["vocab_size"],
      embedding_width=params["hidden_size"],
      dropout_rate=params["layer_postprocess_dropout"],
      padded_decode=params["padded_decode"],
      decode_max_length=params["decode_max_length"],
      dtype=params["dtype"],
      extra_decode_length=params["extra_decode_length"],
      beam_size=params["beam_size"],
      alpha=params["alpha"],
      encoder_layer=encoder_layer,
      decoder_layer=decoder_layer,
      name="transformer_v2")

  if is_train:
    inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
    targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
    internal_model = models.Seq2SeqTransformer(**model_kwargs)
    logits = internal_model([inputs, targets], training=is_train)
    vocab_size = params["vocab_size"]
    label_smoothing = params["label_smoothing"]
    if params["enable_metrics_in_training"]:
      logits = metrics.MetricLayer(vocab_size)([logits, targets])
    logits = tf.keras.layers.Lambda(
        lambda x: x, name="logits", dtype=tf.float32)(
            logits)
    model = tf.keras.Model([inputs, targets], logits)
    loss = metrics.transformer_loss(logits, targets, label_smoothing,
                                    vocab_size)
    model.add_loss(loss)
    return model

  batch_size = params["decode_batch_size"] if params["padded_decode"] else None
  inputs = tf.keras.layers.Input((None,),
                                 batch_size=batch_size,
                                 dtype="int64",
                                 name="inputs")
  internal_model = models.Seq2SeqTransformer(**model_kwargs)
  ret = internal_model([inputs], training=is_train)
  outputs, scores = ret["outputs"], ret["scores"]
  return tf.keras.Model(inputs, [outputs, scores])
class TransformerForwardTest(tf.test.TestCase):

  def setUp(self):
@@ -65,7 +126,7 @@ class TransformerForwardTest(tf.test.TestCase):
    src_model_output = src_model([inputs, targets], training=True)

    # dest_model is the refactored model.
-   dest_model = seq2seq_transformer.create_model(self.params, True)
+   dest_model = _create_model(self.params, True)
    dest_num_weights = _count_params(dest_model)
    self.assertEqual(src_num_weights, dest_num_weights)
    dest_model.set_weights(src_weights)
@@ -82,7 +143,7 @@ class TransformerForwardTest(tf.test.TestCase):
    src_model_output = src_model([inputs], training=False)

    # dest_model is the refactored model.
-   dest_model = seq2seq_transformer.create_model(self.params, False)
+   dest_model = _create_model(self.params, False)
    dest_num_weights = _count_params(dest_model)
    self.assertEqual(src_num_weights, dest_num_weights)
    dest_model.set_weights(src_weights)
......
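The two hunks above only show the middle of each forward test; the overall pattern being exercised (build the legacy model, build the refactored model, check that parameter counts match, port the weights, compare outputs) can be sketched roughly as follows. This is an illustration rather than the test's omitted code, and it assumes the legacy builder official.nlp.transformer.transformer.create_model:

# Rough sketch of the weight-porting equivalence check (illustrative only).
# inputs/targets are the dummy batches created earlier in the test (not shown).
src_model = transformer.create_model(self.params, is_train=True)   # legacy graph
src_weights = src_model.get_weights()
src_num_weights = _count_params(src_model)

dest_model = _create_model(self.params, is_train=True)             # refactored graph
self.assertEqual(src_num_weights, _count_params(dest_model))

# If counts and ordering line up, the legacy weights load directly into the
# refactored model and both graphs should produce matching outputs.
dest_model.set_weights(src_weights)
self.assertAllClose(src_model([inputs, targets], training=True),
                    dest_model([inputs, targets], training=True))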