Commit 658a45f2 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Multichannel attention: Override _build_attention method instead of build()

PiperOrigin-RevId: 318208409
parent 9e2050a6
......@@ -117,8 +117,8 @@ class MultiChannelAttention(attention.MultiHeadAttention):
cross-attention target sequences.
"""
def build(self, input_shape):
super(MultiChannelAttention, self).build(input_shape)
def _build_attention(self, qkv_rank):
super(MultiChannelAttention, self)._build_attention(qkv_rank)
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self, inputs, attention_mask=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment