Commit 5a1bce51 authored by Zhenyu Tan, committed by A. Unique TensorFlower

Replace layers.Transformer with keras_nlp TransformerEncoderBlock in the seq2seq Transformer.

PiperOrigin-RevId: 328798553
parent dd0126f9
@@ -20,6 +20,7 @@ import math
 import tensorflow as tf
 from official.modeling import tf_utils
+from official.nlp.keras_nlp.layers import transformer_encoder_block
 from official.nlp.modeling import layers
 from official.nlp.modeling.ops import beam_search
 from official.nlp.transformer import metrics
@@ -471,16 +472,16 @@ class TransformerEncoder(tf.keras.layers.Layer):
     self.encoder_layers = []
     for i in range(self.num_layers):
       self.encoder_layers.append(
-          layers.Transformer(
+          transformer_encoder_block.TransformerEncoderBlock(
               num_attention_heads=self.num_attention_heads,
-              intermediate_size=self._intermediate_size,
-              intermediate_activation=self._activation,
-              dropout_rate=self._dropout_rate,
-              attention_dropout_rate=self._attention_dropout_rate,
+              inner_dim=self._intermediate_size,
+              inner_activation=self._activation,
+              output_dropout=self._dropout_rate,
+              attention_dropout=self._attention_dropout_rate,
               use_bias=self._use_bias,
               norm_first=self._norm_first,
               norm_epsilon=self._norm_epsilon,
-              intermediate_dropout=self._intermediate_dropout,
+              inner_dropout=self._intermediate_dropout,
               attention_initializer=attention_initializer(input_shape[2]),
               name=("layer_%d" % i)))
     self.output_normalization = tf.keras.layers.LayerNormalization(
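For reference, the renamed arguments map one-to-one onto the old layers.Transformer keywords, so the new block can be constructed directly. A minimal sketch, assuming the import path used in this commit (official.nlp.keras_nlp.layers); the hyperparameter values below are illustrative only, not taken from the model config:

import tensorflow as tf
from official.nlp.keras_nlp.layers import transformer_encoder_block

# Illustrative hyperparameters, not the values used by the seq2seq model.
block = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,           # formerly intermediate_size in layers.Transformer
    inner_activation="relu",  # formerly intermediate_activation
    output_dropout=0.1,       # formerly dropout_rate
    attention_dropout=0.1,    # formerly attention_dropout_rate
    inner_dropout=0.1,        # formerly intermediate_dropout
    use_bias=False,
    norm_first=True,
    norm_epsilon=1e-6,
    name="layer_0")

# The block maps [batch, length, hidden] -> [batch, length, hidden];
# an attention mask can optionally be passed as block([x, mask]).
x = tf.ones((2, 16, 512))
y = block(x)
print(y.shape)  # (2, 16, 512)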