Commit 42c3b8f0 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 279834138
parent 0c5919ef
@@ -69,12 +69,12 @@ def _get_initializer(flags):
 def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False):
   """Creates attention mask when single-side context allowed only."""
   attn_mask = tf.ones([qlen, qlen], dtype=dtype)
-  mask_u = tf.matrix_band_part(attn_mask, 0, -1)
-  mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
+  mask_u = tf.linalg.band_part(attn_mask, 0, -1)
+  mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
   attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
   ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
   if same_length:
-    mask_l = tf.matrix_band_part(attn_mask, -1, 0)
+    mask_l = tf.linalg.band_part(attn_mask, -1, 0)
     ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
   return ret
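For context, tf.linalg.band_part is the current name for the deprecated tf.matrix_band_part; both keep a band of diagonals and zero out the rest, so this hunk is a pure API rename. A small, illustrative sketch of what the updated masking logic produces (the qlen/mlen values below are examples, not taken from this commit):

import tensorflow as tf

attn_mask = tf.ones([3, 3])
mask_u = tf.linalg.band_part(attn_mask, 0, -1)    # upper triangle, incl. diagonal
mask_dia = tf.linalg.band_part(attn_mask, 0, 0)   # diagonal only
attn_mask_pad = tf.zeros([3, 2])                  # qlen=3, mlen=2
ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
print(ret.numpy())
# [[0. 0. 0. 1. 1.]
#  [0. 0. 0. 0. 1.]
#  [0. 0. 0. 0. 0.]]
# A 1 marks a future position a query token is not allowed to attend to;
# the first mlen columns (memory) stay fully attendable.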
@@ -872,7 +872,7 @@ class PretrainingXLNetModel(tf.keras.Model):
         perm_mask=perm_mask,
         target_mapping=target_mapping,
         inp_q=inp_q)
-    lm_loss = self.lmloss_layer(
+    lm_loss, _ = self.lmloss_layer(
         hidden=transformerxl_output,
         target=target,
         lookup_table=self.transformerxl_model.embedding_lookup.lookup_table,
@@ -1033,7 +1033,7 @@ class LMLossLayer(tf.keras.layers.Layer):
     total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask)
-    return total_loss
+    return total_loss, logits
 
 
 class Summarization(tf.keras.layers.Layer):
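Together, the last two hunks change LMLossLayer.call to return the per-token logits alongside the masked-LM loss, so a caller that only needs the loss (like the pretraining model above) now unpacks a 2-tuple. A hedged, self-contained sketch of that calling pattern; ToyLMLossLayer and all shapes below are illustrative stand-ins, not the real LMLossLayer API:

import tensorflow as tf

class ToyLMLossLayer(tf.keras.layers.Layer):
  """Hypothetical stand-in that mimics the new (loss, logits) return shape."""

  def call(self, hidden, target, lookup_table):
    # Project hidden states onto the embedding table to get vocabulary logits.
    logits = tf.matmul(hidden, lookup_table, transpose_b=True)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=target, logits=logits)
    total_loss = tf.reduce_mean(loss)
    return total_loss, logits   # both values, as in the diff above

layer = ToyLMLossLayer()
hidden = tf.random.normal([4, 8])         # [num_predictions, hidden_dim]
lookup_table = tf.random.normal([32, 8])  # [vocab_size, hidden_dim]
target = tf.constant([1, 5, 7, 3])        # token ids to predict

# Pretraining only needs the loss, so the logits are discarded with `_`.
lm_loss, _ = layer(hidden, target=target, lookup_table=lookup_table)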