"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "66fa8ceaeaa6fe12f1bd4a5e6b0a924f59f715d9"
Unverified commit d06c5a2a authored by Thomas Wolf, committed by GitHub

Merge pull request #1120 from CrafterKolyan/patch-3

Change attention mask dtype to be bool. Fix #1119
Parents: edc5222f 53282b5b
@@ -1142,10 +1142,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
             else:
                 mask_shift_len = qlen
             dec_attn_mask = (torch.triu(all_ones, 1+mlen)
-                    + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
+                    + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1
         else:
             dec_attn_mask = torch.triu(
-                word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
+                word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]
 
         hids = []
         attentions = []
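Why the change matters: starting with PyTorch 1.2, using uint8 (byte) tensors as masks in indexing and masked_fill-style ops emits a deprecation warning in favor of torch.bool, which is presumably the warning behind #1119. Below is a minimal, standalone sketch of the masking pattern this hunk touches, using made-up toy sizes (qlen, mlen, and the word_emb stand-in are illustrative, not the model's actual tensors):

import torch

# Toy sizes standing in for the model's tensors (illustrative values only).
qlen, mlen = 4, 2                # new tokens vs. cached "memory" tokens
klen = qlen + mlen               # keys cover memory + new tokens

word_emb = torch.zeros(qlen, 8)  # stand-in embedding, only used for new_ones()

# Same construction as the second branch of the hunk above: key positions more
# than mlen steps ahead of each query are masked out (True = masked).
dec_attn_mask = torch.triu(
    word_emb.new_ones(qlen, klen), diagonal=1 + mlen).bool()[:, :, None]

print(dec_attn_mask.dtype)             # torch.bool (no uint8 deprecation warning)
print(dec_attn_mask.squeeze(-1).int())
# tensor([[0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0]], dtype=torch.int32)

With .bool() the mask stays on the supported dtype when it is later consumed by the attention's masking ops; a .byte() mask still worked at the time but triggered the PyTorch 1.2 deprecation warning.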