Unverified Commit fab3b518 authored by ZOHETH's avatar ZOHETH Committed by GitHub
Browse files

fix deprecated tf method (#14671)

tf.matrix_band_part -> tf.linalg.band_part
parent 65b20b73
...@@ -497,12 +497,12 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -497,12 +497,12 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
""" """
attn_mask = tf.ones([qlen, qlen]) attn_mask = tf.ones([qlen, qlen])
mask_u = tf.matrix_band_part(attn_mask, 0, -1) mask_u = tf.linalg.band_part(attn_mask, 0, -1)
mask_dia = tf.matrix_band_part(attn_mask, 0, 0) mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
attn_mask_pad = tf.zeros([qlen, mlen]) attn_mask_pad = tf.zeros([qlen, mlen])
ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1) ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
if self.same_length: if self.same_length:
mask_l = tf.matrix_band_part(attn_mask, -1, 0) mask_l = tf.linalg.band_part(attn_mask, -1, 0)
ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1) ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
return ret return ret
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment