"vscode:/vscode.git/clone" did not exist on "d9ffb87efb939cec920abd10a5843906d81b1815"
Unverified commit 1d690289, authored by ZhuBaohe, committed by GitHub
Browse files

fix (#4410)

parent b86e42e0
......@@ -586,6 +586,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if data_mask is not None:
# all mems can be attended to
if mlen > 0:
mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
data_mask = tf.concat([mems_mask, data_mask], axis=1)
if attn_mask is None:
......@@ -598,6 +599,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if attn_mask is not None:
non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
if mlen > 0:
non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
else:
......@@ -621,8 +623,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
# Segment embedding
if token_type_ids is not None:
# Convert `token_type_ids` to one-hot `seg_mat`
if mlen > 0:
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
cat_ids = tf.concat([mem_pad, token_type_ids], 0)
else:
cat_ids = token_type_ids
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
......@@ -640,14 +645,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
head_mask = head_mask.to(
dtype=next(self.parameters()).dtype
) # switch to fload if need + fp16 compatibility
raise NotImplementedError
else:
head_mask = [None] * self.n_layer
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment