Commit 1eb125fb authored by thomwolf's avatar thomwolf
Browse files

be sure we have uint8

parent 7fba47b7
......@@ -1135,7 +1135,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
mlen = mems[0].size(0) if mems is not None else 0
klen = mlen + qlen
if self.same_length:
all_ones = word_emb.new_ones(qlen, klen)
all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
mask_len = klen - self.mem_len
if mask_len > 0:
mask_shift_len = qlen - mask_len
......@@ -1145,7 +1145,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
+ torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
else:
dec_attn_mask = torch.triu(
word_emb.new_ones(qlen, klen), diagonal=1+mlen)[:,:,None]
word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]
hids = []
attentions = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment