Commit dc8094a6 authored by yongshk's avatar yongshk
Browse files

Initial commit

parent 01153e93
...@@ -667,11 +667,19 @@ class MemTransformerLM(nn.Module): ...@@ -667,11 +667,19 @@ class MemTransformerLM(nn.Module):
mask_shift_len = qlen - mask_len mask_shift_len = qlen - mask_len
else: else:
mask_shift_len = qlen mask_shift_len = qlen
# In PyTorch 2.x the Byte dtype is deprecated as a mask type; boolean masks are used instead, which better support logical operations.
# dec_attn_mask = (torch.triu(all_ones, 1+mlen)
# + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
dec_attn_mask = (torch.triu(all_ones, 1+mlen) dec_attn_mask = (torch.triu(all_ones, 1+mlen)
+ torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1 + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1
else: else:
# In PyTorch 2.x the Byte dtype is deprecated as a mask type; boolean masks are used instead, which better support logical operations.
# dec_attn_mask = torch.triu(
# word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
dec_attn_mask = torch.triu( dec_attn_mask = torch.triu(
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None] word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]
hids = [] hids = []
if self.attn_type == 0: # default if self.attn_type == 0: # default
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment