Commit 42c3b8f0 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 279834138
parent 0c5919ef
...@@ -69,12 +69,12 @@ def _get_initializer(flags): ...@@ -69,12 +69,12 @@ def _get_initializer(flags):
def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False): def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False):
"""Creates attention mask when single-side context allowed only.""" """Creates attention mask when single-side context allowed only."""
attn_mask = tf.ones([qlen, qlen], dtype=dtype) attn_mask = tf.ones([qlen, qlen], dtype=dtype)
mask_u = tf.matrix_band_part(attn_mask, 0, -1) mask_u = tf.linalg.band_part(attn_mask, 0, -1)
mask_dia = tf.matrix_band_part(attn_mask, 0, 0) mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype) attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1) ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
if same_length: if same_length:
mask_l = tf.matrix_band_part(attn_mask, -1, 0) mask_l = tf.linalg.band_part(attn_mask, -1, 0)
ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1) ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
return ret return ret
...@@ -872,7 +872,7 @@ class PretrainingXLNetModel(tf.keras.Model): ...@@ -872,7 +872,7 @@ class PretrainingXLNetModel(tf.keras.Model):
perm_mask=perm_mask, perm_mask=perm_mask,
target_mapping=target_mapping, target_mapping=target_mapping,
inp_q=inp_q) inp_q=inp_q)
lm_loss = self.lmloss_layer( lm_loss, _ = self.lmloss_layer(
hidden=transformerxl_output, hidden=transformerxl_output,
target=target, target=target,
lookup_table=self.transformerxl_model.embedding_lookup.lookup_table, lookup_table=self.transformerxl_model.embedding_lookup.lookup_table,
...@@ -1033,7 +1033,7 @@ class LMLossLayer(tf.keras.layers.Layer): ...@@ -1033,7 +1033,7 @@ class LMLossLayer(tf.keras.layers.Layer):
total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask) total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask)
return total_loss return total_loss, logits
class Summarization(tf.keras.layers.Layer): class Summarization(tf.keras.layers.Layer):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment