Unverified Commit d4692ad1 authored by Yih-Dar, committed by GitHub

Fix dec_attn_mask in TFTransfoXLMainLayer (#15665)



* fix attn

* clean-up
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent b87c044c
@@ -597,14 +597,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None else 0
klen = mlen + qlen
-attn_mask = tf.ones([qlen, qlen])
-mask_u = tf.linalg.band_part(attn_mask, 0, -1)
-mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
-attn_mask_pad = tf.zeros([qlen, mlen])
-dec_attn_mask = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
-if self.same_length:
-    mask_l = tf.linalg.band_part(attn_mask, -1, 0)
-    dec_attn_mask = tf.concat([dec_attn_mask[:, :qlen] + mask_l - mask_dia, dec_attn_mask[:, qlen:]], 1)
+# Compute decoder attention mask
# ::: PyTorch masking code for reference :::
# if self.same_length:
#     all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
@@ -619,6 +613,21 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
# dec_attn_mask = torch.triu(
#     word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]
+# TensorFlow version
+dec_attn_mask = 1 - tf.linalg.band_part(
+    tf.ones([qlen, klen], dtype=tf.int32), -1, mlen
+)  # (qlen, klen): 1 where key position j > query position i + mlen, i.e. future tokens
+if self.same_length:
+    mask_len = klen - self.mem_len
+    if mask_len > 0:
+        mask_shift_len = qlen - mask_len
+    else:
+        mask_shift_len = qlen
+    if mask_shift_len >= 1:
+        dec_attn_mask += 1 - tf.linalg.band_part(tf.ones([qlen, klen], dtype=tf.int32), mask_shift_len - 1, -1)
+    else:
+        dec_attn_mask += tf.linalg.band_part(tf.ones([qlen, klen], dtype=tf.int32), -1, -mask_shift_len)
hids = []
attentions = [] if inputs["output_attentions"] else None
if self.attn_type == 0:  # default
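For reference, below is a minimal standalone sketch of what the new masking code computes. It is not part of the commit; the helper name build_dec_attn_mask and the toy sizes are illustrative only. It reproduces the band_part logic added above and prints the resulting mask for a small query/memory length, where 1 marks a key position that must not be attended to.

import tensorflow as tf

def build_dec_attn_mask(qlen, mlen, mem_len, same_length):
    # Hypothetical helper mirroring the TF code path added in this commit.
    klen = qlen + mlen
    # 1 where key position j > query position i + mlen, i.e. tokens in the future.
    dec_attn_mask = 1 - tf.linalg.band_part(tf.ones([qlen, klen], dtype=tf.int32), -1, mlen)
    if same_length:
        mask_len = klen - mem_len
        mask_shift_len = qlen - mask_len if mask_len > 0 else qlen
        if mask_shift_len >= 1:
            # Additionally mask keys mask_shift_len or more positions behind the query,
            # so every query attends to the same number of positions.
            dec_attn_mask += 1 - tf.linalg.band_part(
                tf.ones([qlen, klen], dtype=tf.int32), mask_shift_len - 1, -1
            )
        else:
            dec_attn_mask += tf.linalg.band_part(
                tf.ones([qlen, klen], dtype=tf.int32), -1, -mask_shift_len
            )
    return dec_attn_mask

print(build_dec_attn_mask(qlen=3, mlen=2, mem_len=2, same_length=False).numpy())
# [[0 0 0 1 1]
#  [0 0 0 0 1]
#  [0 0 0 0 0]]
# Each of the 3 queries can see the 2 memory slots plus itself and earlier queries.

print(build_dec_attn_mask(qlen=3, mlen=2, mem_len=2, same_length=True).numpy())
# [[1 0 0 1 1]
#  [1 1 0 0 1]
#  [1 1 1 0 0]]
# With same_length=True every query attends to the same number of key positions.

For these toy sizes the result agrees with the PyTorch reference logic quoted in the diff comments (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len)).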