Unverified Commit 1d690289 authored by ZhuBaohe's avatar ZhuBaohe Committed by GitHub
Browse files

fix (#4410)

parent b86e42e0
...@@ -586,8 +586,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -586,8 +586,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if data_mask is not None: if data_mask is not None:
# all mems can be attended to # all mems can be attended to
mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float) if mlen > 0:
data_mask = tf.concat([mems_mask, data_mask], axis=1) mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
data_mask = tf.concat([mems_mask, data_mask], axis=1)
if attn_mask is None: if attn_mask is None:
attn_mask = data_mask[:, :, :, None] attn_mask = data_mask[:, :, :, None]
else: else:
...@@ -598,7 +599,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -598,7 +599,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if attn_mask is not None: if attn_mask is not None:
non_tgt_mask = -tf.eye(qlen, dtype=dtype_float) non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1) if mlen > 0:
non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float) non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
else: else:
non_tgt_mask = None non_tgt_mask = None
...@@ -621,8 +623,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -621,8 +623,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
# Segment embedding # Segment embedding
if token_type_ids is not None: if token_type_ids is not None:
# Convert `token_type_ids` to one-hot `seg_mat` # Convert `token_type_ids` to one-hot `seg_mat`
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32) if mlen > 0:
cat_ids = tf.concat([mem_pad, token_type_ids], 0) mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
cat_ids = tf.concat([mem_pad, token_type_ids], 0)
else:
cat_ids = token_type_ids
# `1` indicates not in the same segment [qlen x klen x bsz] # `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32) seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
...@@ -640,14 +645,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -640,14 +645,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if head_mask is not None: if head_mask is not None:
if head_mask.dim() == 1: raise NotImplementedError
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
head_mask = head_mask.to(
dtype=next(self.parameters()).dtype
) # switch to fload if need + fp16 compatibility
else: else:
head_mask = [None] * self.n_layer head_mask = [None] * self.n_layer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment