Commit 0be6a2a6 authored by thomwolf

be sure we have uint8

parent 38b79b5a
...
@@ -1135,7 +1135,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         mlen = mems[0].size(0) if mems is not None else 0
         klen = mlen + qlen
         if self.same_length:
-            all_ones = word_emb.new_ones(qlen, klen)
+            all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
             mask_len = klen - self.mem_len
             if mask_len > 0:
                 mask_shift_len = qlen - mask_len
@@ -1145,7 +1145,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                     + torch.tril(all_ones, -mask_shift_len))[:, :, None]  # -1
         else:
-            dec_attn_mask = torch.triu(
-                word_emb.new_ones(qlen, klen), diagonal=1+mlen)[:,:,None]
+            dec_attn_mask = torch.triu(
+                word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]
         hids = []
         attentions = []
...
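For context, a minimal standalone sketch (not part of the commit) of why the explicit dtype matters: word_emb.new_ones(qlen, klen) inherits the floating-point dtype of the embedding tensor, whereas the attention-masking step expects a byte/boolean mask, so the change requests torch.uint8 explicitly. The sizes below are hypothetical, chosen only for illustration.

import torch

# Hypothetical sizes for illustration only.
qlen, mlen, d_model = 4, 2, 8
klen = mlen + qlen
word_emb = torch.randn(qlen, 1, d_model)   # float32 embeddings

# Before the fix: the mask inherits word_emb's float dtype.
float_mask = torch.triu(word_emb.new_ones(qlen, klen), diagonal=1 + mlen)
print(float_mask.dtype)                     # torch.float32

# After the fix: the mask is explicitly uint8.
dec_attn_mask = torch.triu(
    word_emb.new_ones((qlen, klen), dtype=torch.uint8),
    diagonal=1 + mlen)[:, :, None]
print(dec_attn_mask.dtype, dec_attn_mask.shape)  # torch.uint8, (4, 6, 1)

# Typical downstream use of such a mask (illustrative, not the model code):
attn_score = torch.randn(qlen, klen, 1)
attn_score = attn_score.masked_fill(dec_attn_mask.bool(), float('-inf'))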