Commit 58e805e0 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower

Move TransformerDecoderLayer to modeling/

PiperOrigin-RevId: 317330705
parent 7d210ec0
@@ -9,13 +9,17 @@ assemble new layers, networks, or models.
  initialization parameters.
* [MultiHeadAttention](attention.py) implements an optionally masked attention
-  between two tensors, from_tensor and to_tensor, as described in
+  between query, key, and value tensors, as described in
  ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
  `from_tensor` and `to_tensor` are the same, then this is self-attention.
* [CachedAttention](attention.py) implements an attention layer with a cache
  used for auto-regressive decoding.
+* [MultiChannelAttention](multi_channel_attention.py) implements a variant of
+  multi-head attention that can be used to merge multiple attention streams
+  for cross-attention.
* [TalkingHeadsAttention](talking_heads_attention.py) implements talking-heads
  attention, as described in
  ["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436).
......
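The README hunk above describes the attention layers at a high level. As a rough illustration (not part of the commit), the sketch below builds a MultiHeadAttention layer and runs plain self-attention; the shapes, mask values, and the demo layer name are assumptions chosen for the example, and the constructor/call pattern follows how the layers are invoked elsewhere in this diff.

```python
import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import attention

# Illustrative shapes only.
batch_size, seq_length, hidden_size = 2, 4, 16

# Constructor arguments mirror the ones the decoder layer in this commit uses
# (num_heads, key_size, output_shape, name).
self_attn = attention.MultiHeadAttention(
    num_heads=2, key_size=8, output_shape=hidden_size,
    name="demo_self_attention")

x = tf.zeros([batch_size, seq_length, hidden_size], dtype=tf.float32)
# 1.0 marks positions that may be attended to.
mask = np.ones((batch_size, seq_length, seq_length), dtype=np.float32)

# Passing the same tensor as query and key/value gives self-attention.
output = self_attn([x, x], mask)
print(output.shape)  # Expected: (2, 4, 16).
```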
@@ -20,6 +20,7 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
+from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import PositionEmbedding
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
......
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Multi-channel decoder."""
+"""Multi-channel Attention."""
+# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
@@ -24,11 +25,25 @@ import math
import tensorflow as tf
from official.modeling import tf_utils
-from official.nlp.modeling import layers
+from official.nlp.modeling.layers import attention
+from official.nlp.modeling.layers import dense_einsum
+from official.nlp.modeling.layers import masked_softmax
-class DocAttention(tf.keras.layers.Layer):
-"""Documents Attention layer."""
+class VotingAttention(tf.keras.layers.Layer):
+"""Voting Attention layer.
+Arguments:
+num_heads: the number of attention heads.
+head_size: per-head hidden size.
+kernel_initializer: Initializer for dense layer kernels.
+bias_initializer: Initializer for dense layer biases.
+kernel_regularizer: Regularizer for dense layer kernels.
+bias_regularizer: Regularizer for dense layer biases.
+activity_regularizer: Regularizer for dense layer activity.
+kernel_constraint: Constraint for dense layer kernels.
+bias_constraint: Constraint for dense layer biases.
+"""
def __init__(self,
num_heads,
@@ -41,7 +56,7 @@ class DocAttention(tf.keras.layers.Layer):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
-super(DocAttention, self).__init__(**kwargs)
+super(VotingAttention, self).__init__(**kwargs)
self._num_heads = num_heads
self._head_size = head_size
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
@@ -52,7 +67,7 @@ class DocAttention(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
def build(self, unused_input_shapes):
-self._query_dense = layers.DenseEinsum(
+self._query_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
@@ -63,7 +78,7 @@ class DocAttention(tf.keras.layers.Layer):
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="encdocatt_query")
-self._key_dense = layers.DenseEinsum(
+self._key_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
@@ -74,7 +89,7 @@ class DocAttention(tf.keras.layers.Layer):
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="encdocatt_key")
-super(DocAttention, self).build(unused_input_shapes)
+super(VotingAttention, self).build(unused_input_shapes)
def call(self, encoder_outputs, doc_attention_mask):
num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1]
@@ -95,12 +110,16 @@ class DocAttention(tf.keras.layers.Layer):
return tf.nn.softmax(doc_attention_probs + infadder)
-class MultiChannelAttention(layers.MultiHeadAttention):
-"""Multi-channel Attention layer."""
+class MultiChannelAttention(attention.MultiHeadAttention):
+"""Multi-channel Attention layer.
+Introduced in: https://arxiv.org/abs/2001.09386. Expects multiple
+cross-attention target sequences.
+"""
def build(self, input_shape):
super(MultiChannelAttention, self).build(input_shape)
-self._masked_softmax = layers.MaskedSoftmax(mask_expansion_axes=[2])
+self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self, inputs, attention_mask=None):
from_tensor = inputs[0]
......
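As a usage note (not part of the commit), the sketch below exercises the renamed VotingAttention layer in the same way as the unit test in the next hunk; the shapes are illustrative assumptions.

```python
import numpy as np

from official.nlp.modeling.layers import multi_channel_attention

# 2 examples, 3 candidate documents, 10 tokens per document, hidden size 16.
doc_attention = multi_channel_attention.VotingAttention(num_heads=2, head_size=8)

# [batch, num_docs, seq_len, hidden]: per-document encoder outputs.
encoder_outputs = np.zeros((2, 3, 10, 16), dtype=np.float32)
# [batch, num_docs]: 1.0 marks a valid (non-padded) document.
doc_mask = np.ones((2, 3), dtype=np.float32)

# Returns softmax-normalized per-document weights, which NHNet combines with
# MultiChannelAttention to merge the documents' cross-attention streams.
doc_scores = doc_attention(encoder_outputs, doc_mask)
```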
@@ -22,14 +22,15 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
-from official.nlp.nhnet import multi_channel_attention
+from official.nlp.modeling.layers import multi_channel_attention
class MultiChannelAttentionTest(tf.test.TestCase):
def test_doc_attention(self):
num_heads = 2
-doc_attention = multi_channel_attention.DocAttention(num_heads, head_size=8)
+doc_attention = multi_channel_attention.VotingAttention(
+num_heads, head_size=8)
num_docs = 3
inputs = np.zeros((2, num_docs, 10, 16), dtype=np.float32)
doc_mask = np.zeros((2, num_docs), dtype=np.float32)
......
@@ -24,6 +24,7 @@ import tensorflow as tf
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
+from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.modeling.layers.util import tf_function_if_eager
@@ -236,3 +237,145 @@ class CompiledTransformer(Transformer):
@tf_function_if_eager(experimental_compile=True)
def call(self, inputs):
return super(CompiledTransformer, self).call(inputs)
@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerDecoderLayer(tf.keras.layers.Layer):
"""Single transformer layer for decoder.
It has three sub-layers:
(1) a multi-head self-attention mechanism.
(2) an encoder-decoder attention.
(3) a positionwise fully connected feed-forward network.
"""
def __init__(self,
hidden_size=768,
num_attention_heads=12,
intermediate_size=3072,
intermediate_activation="relu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
multi_channel_cross_attention=False,
**kwargs):
super(TransformerDecoderLayer, self).__init__(**kwargs)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.intermediate_activation = tf.keras.activations.get(
intermediate_activation)
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.multi_channel_cross_attention = multi_channel_cross_attention
self._kernel_initializer = tf.keras.initializers.TruncatedNormal(
stddev=initializer_range)
self._bias_initializer = tf.keras.initializers.get("zeros")
if self.multi_channel_cross_attention:
self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
else:
self._cross_attention_cls = attention.MultiHeadAttention
if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (self.hidden_size, self.num_attention_heads))
self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
def build(self, input_shape):
# Self attention.
self.self_attention = attention.CachedAttention(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_probs_dropout_prob,
kernel_initializer=self._kernel_initializer,
name="self_attention")
self.self_attention_output_dense = dense_einsum.DenseEinsum(
output_shape=self.hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
name="self_attention_output")
self.self_attention_dropout = tf.keras.layers.Dropout(
rate=self.hidden_dropout_prob)
self.self_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
# Encoder-decoder attention.
self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_probs_dropout_prob,
output_shape=self.hidden_size,
kernel_initializer=self._kernel_initializer,
name="attention/encdec")
self.encdec_attention_dropout = tf.keras.layers.Dropout(
rate=self.hidden_dropout_prob)
self.encdec_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
# Feed-forward projection.
self.intermediate_dense = dense_einsum.DenseEinsum(
output_shape=self.intermediate_size,
activation=None,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
name="intermediate")
self.intermediate_activation_layer = tf.keras.layers.Activation(
self.intermediate_activation)
self.output_dense = dense_einsum.DenseEinsum(
output_shape=self.hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
name="output")
self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12)
super(TransformerDecoderLayer, self).build(input_shape)
def common_layers_with_encoder(self):
"""Gets layer objects that can make a Transformer encoder block."""
return [
self.self_attention, self.self_attention_layer_norm,
self.intermediate_dense, self.output_dense, self.output_layer_norm
]
def call(self, inputs, cache=None, decode_loop_step=None):
if self.multi_channel_cross_attention:
if len(inputs) != 5:
raise ValueError(
"TransformerDecoderLayer must have 5 inputs, when it uses "
"multi_channel_cross_attention. But it got: %d" % len(inputs))
elif len(inputs) != 4:
raise ValueError(
"TransformerDecoderLayer must have 4 inputs, but it got: %d" %
len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
self_attention_inputs = [input_tensor, input_tensor]
self_attention_output, cache = self.self_attention(
self_attention_inputs,
attention_mask=self_attention_mask,
cache=cache,
decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_dropout(self_attention_output)
self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output)
cross_attn_inputs = [self_attention_output, memory]
if self.multi_channel_cross_attention:
# Accesses the fifth input tensor for the doc-attention probabilities.
cross_attn_inputs.append(inputs[-1])
attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
attention_output = self.encdec_attention_dropout(attention_output)
attention_output = self.encdec_attention_layer_norm(self_attention_output +
attention_output)
intermediate_output = self.intermediate_dense(attention_output)
intermediate_output = self.intermediate_activation_layer(
intermediate_output)
layer_output = self.output_dense(intermediate_output)
layer_output = self.output_dropout(layer_output)
layer_output = self.output_layer_norm(layer_output + attention_output)
return layer_output, cache
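For orientation (not part of the commit), here is a rough sketch of the four-input contract of the relocated TransformerDecoderLayer. It mirrors the unit test added in the next hunk; the shapes, mask values, and initial cache length are illustrative assumptions.

```python
import tensorflow as tf

from official.nlp.modeling.layers import transformer

batch_size, target_len, source_len, hidden, heads = 2, 4, 4, 16, 2

layer = transformer.TransformerDecoderLayer(
    hidden_size=hidden,
    num_attention_heads=heads,
    intermediate_size=32,
    intermediate_activation="relu")

targets = tf.zeros([batch_size, target_len, hidden])   # decoder-side hidden states
memory = tf.zeros([batch_size, source_len, hidden])    # encoder outputs
cross_mask = tf.ones([batch_size, target_len, source_len])  # encoder-decoder attention mask
self_mask = tf.ones([batch_size, target_len, target_len])   # self-attention mask (causal in real decoding)

# Start from an empty cache; the cached self-attention keys/values for the
# processed positions are returned alongside the layer output.
cache = {
    "key": tf.zeros([batch_size, 0, heads, hidden // heads]),
    "value": tf.zeros([batch_size, 0, heads, hidden // heads]),
}

output, cache = layer([targets, memory, cross_mask, self_mask], cache=cache)
print(output.shape)          # (2, 4, 16)
print(cache["value"].shape)  # (2, 4, 2, 8)
```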
@@ -215,5 +215,41 @@ class TransformerLayerTest(keras_parameterized.TestCase):
self.assertAllEqual([1, input_length, width], output_data.shape)
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
return {
'key':
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32),
'value':
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32)
}
@keras_parameterized.run_all_keras_modes
class TransformerDecoderLayerTest(keras_parameterized.TestCase):
def test_decoder_block_with_cache(self):
num_attention_heads = 2
hidden_size = 16
decoder_block = transformer.TransformerDecoderLayer(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
initializer_range=0.1)
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_tensor, dummy_mask, dummy_mask]
cache = _create_cache(2, 0, num_attention_heads,
hidden_size // num_attention_heads)
output, cache = decoder_block(inputs, cache)
self.assertEqual(output.shape, (2, 4, hidden_size))
self.assertEqual(cache['value'].shape, (2, 4, 2, 8))
if __name__ == '__main__':
tf.test.main()
@@ -22,151 +22,10 @@ from __future__ import print_function
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
-from official.nlp.nhnet import multi_channel_attention
+from official.nlp.modeling.layers import transformer
from official.nlp.transformer import model_utils as transformer_utils
class TransformerDecoderBlock(tf.keras.layers.Layer):
"""Single transformer layer for decoder.
It has three sub-layers:
(1) a multi-head self-attention mechanism.
(2) an encoder-decoder attention.
(3) a positionwise fully connected feed-forward network.
"""
def __init__(self,
hidden_size=768,
num_attention_heads=12,
intermediate_size=3072,
intermediate_activation="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
multi_channel_cross_attention=False,
**kwargs):
super(TransformerDecoderBlock, self).__init__(**kwargs)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.intermediate_activation = tf_utils.get_activation(
intermediate_activation)
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.multi_channel_cross_attention = multi_channel_cross_attention
self._kernel_initializer = tf.keras.initializers.TruncatedNormal(
stddev=initializer_range)
self._bias_initializer = tf.keras.initializers.get("zeros")
if self.multi_channel_cross_attention:
self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
else:
self._cross_attention_cls = layers.MultiHeadAttention
if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (self.hidden_size, self.num_attention_heads))
self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
def build(self, input_shape):
# Self attention.
self.self_attention = layers.CachedAttention(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_probs_dropout_prob,
kernel_initializer=self._kernel_initializer,
name="self_attention")
self.self_attention_output_dense = layers.DenseEinsum(
output_shape=self.hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
name="self_attention_output")
self.self_attention_dropout = tf.keras.layers.Dropout(
rate=self.hidden_dropout_prob)
self.self_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
# Encoder-decoder attention.
self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_probs_dropout_prob,
output_shape=self.hidden_size,
kernel_initializer=self._kernel_initializer,
name="attention/encdec")
self.encdec_attention_dropout = tf.keras.layers.Dropout(
rate=self.hidden_dropout_prob)
self.encdec_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
# Feed-forward projection.
self.intermediate_dense = layers.DenseEinsum(
output_shape=self.intermediate_size,
activation=None,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
name="intermediate")
self.intermediate_activation_layer = tf.keras.layers.Activation(
self.intermediate_activation)
self.output_dense = layers.DenseEinsum(
output_shape=self.hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
name="output")
self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12)
super(TransformerDecoderBlock, self).build(input_shape)
def common_layers_with_encoder(self):
"""Gets layer objects that can make a Transformer encoder block."""
return [
self.self_attention, self.self_attention_layer_norm,
self.intermediate_dense, self.output_dense, self.output_layer_norm
]
def call(self, inputs, cache=None, decode_loop_step=None):
if self.multi_channel_cross_attention:
if len(inputs) != 5:
raise ValueError(
"TransformerDecoderBlock must have 5 inputs, when it uses "
"multi_channel_cross_attention. But it got: %d" % len(inputs))
elif len(inputs) != 4:
raise ValueError(
"TransformerDecoderBlock must have 4 inputs, but it got: %d" %
len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
self_attention_inputs = [input_tensor, input_tensor]
self_attention_output, cache = self.self_attention(
self_attention_inputs,
attention_mask=self_attention_mask,
cache=cache,
decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_dropout(self_attention_output)
self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output)
cross_attn_inputs = [self_attention_output, memory]
if self.multi_channel_cross_attention:
# Accesses the 5-th input tensor for the doc-attention probabilities.
cross_attn_inputs.append(inputs[-1])
attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
attention_output = self.encdec_attention_dropout(attention_output)
attention_output = self.encdec_attention_layer_norm(self_attention_output +
attention_output)
intermediate_output = self.intermediate_dense(attention_output)
intermediate_output = self.intermediate_activation_layer(
intermediate_output)
layer_output = self.output_dense(intermediate_output)
layer_output = self.output_dropout(layer_output)
layer_output = self.output_layer_norm(layer_output + attention_output)
return layer_output, cache
class TransformerDecoder(tf.keras.layers.Layer):
"""Transformer decoder stack."""
@@ -200,7 +59,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
self.layers = []
for i in range(self.num_hidden_layers):
self.layers.append(
-TransformerDecoderBlock(
+transformer.TransformerDecoderLayer(
hidden_size=self.hidden_size,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
......
@@ -26,17 +26,6 @@ from official.nlp.nhnet import decoder
from official.nlp.nhnet import utils
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
return {
"key":
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32),
"value":
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32)
}
class DecoderTest(tf.test.TestCase):
def setUp(self):
@@ -56,26 +45,6 @@ class DecoderTest(tf.test.TestCase):
decoder_block.build(None)
self.assertEqual(len(decoder_block.layers), self._config.num_hidden_layers)
def test_decoder_block_with_cache(self):
decoder_block = decoder.TransformerDecoderBlock(
hidden_size=self._config.hidden_size,
num_attention_heads=self._config.num_attention_heads,
intermediate_size=self._config.intermediate_size,
intermediate_activation=self._config.hidden_act,
hidden_dropout_prob=self._config.hidden_dropout_prob,
attention_probs_dropout_prob=self._config.attention_probs_dropout_prob,
initializer_range=self._config.initializer_range)
# Forward path.
dummy_tensor = tf.zeros([2, 4, self._config.hidden_size], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_tensor, dummy_mask, dummy_mask]
cache = _create_cache(
2, 0, self._config.num_attention_heads,
self._config.hidden_size // self._config.num_attention_heads)
output, cache = decoder_block(inputs, cache)
self.assertEqual(output.shape, (2, 4, self._config.hidden_size))
self.assertEqual(cache["value"].shape, (2, 4, 2, 8))
def test_bert_decoder(self):
seq_length = 10
encoder_input_ids = tf.keras.layers.Input(
......
@@ -27,9 +27,9 @@ from typing import Optional, Text
from official.modeling import tf_utils
from official.modeling.hyperparams import params_dict
from official.nlp.modeling import networks
+from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.nhnet import configs
from official.nlp.nhnet import decoder
-from official.nlp.nhnet import multi_channel_attention
from official.nlp.nhnet import utils
from official.nlp.transformer import beam_search
@@ -273,7 +273,7 @@ class NHNet(Bert2Bert):
def __init__(self, params, bert_layer, decoder_layer, name=None):
super(NHNet, self).__init__(params, bert_layer, decoder_layer, name=name)
-self.doc_attention = multi_channel_attention.DocAttention(
+self.doc_attention = multi_channel_attention.VotingAttention(
num_heads=params.num_decoder_attn_heads,
head_size=params.hidden_size // params.num_decoder_attn_heads)
......