Unverified Commit 3db6604e authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Fix core lib warning] Fix using `torch.tensor` for tensor conversion (#5028)

parent 9731e023
......@@ -341,7 +341,7 @@ def pack_padded_tensor(input, lengths):
def boolean_mask(input, mask):
if "bool" not in str(mask.dtype):
mask = th.tensor(mask, dtype=th.bool)
mask = th.as_tensor(mask, dtype=th.bool)
return input[mask]
......
......@@ -729,8 +729,8 @@ class MultiHeadAttention(nn.Module):
max_len_x = max(lengths_x)
max_len_mem = max(lengths_mem)
device = x.device
lengths_x = th.tensor(lengths_x, dtype=th.int64, device=device)
lengths_mem = th.tensor(lengths_mem, dtype=th.int64, device=device)
lengths_x = th.as_tensor(lengths_x, dtype=th.int64, device=device)
lengths_mem = th.as_tensor(lengths_mem, dtype=th.int64, device=device)
queries = self.proj_q(x).view(-1, self.num_heads, self.d_head)
keys = self.proj_k(mem).view(-1, self.num_heads, self.d_head)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment