Commit 89988364 authored by Simon Layton

Fix test failures and warnings

Attention output was in bnij ordering instead of the ijbn ordering that
everything else expects. This was an oversight on my part; the fix keeps the
attention inputs/outputs identical to the original code.

Also moved back from tensor slicing to index_select in rel_shift_bnij to
make the tracer happy.
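As a quick illustration of the slicing vs. index_select point, here is a standalone sketch (tensor shapes are made up for illustration, not taken from the model):

```python
import torch

x = torch.rand(2, 4, 5, 9)  # e.g. (batch, heads, qlen, shifted key positions)
klen = 5

# Tensor-slice form: was faster in the author's testing, but per the commit
# message the tracer dislikes it and a traced model fails if klen changes.
sliced = x[:, :, :, :klen]

# index_select form: produces the same values, and per the commit message it
# traces correctly even when klen varies between runs.
selected = torch.index_select(
    x, 3, torch.arange(klen, device=x.device, dtype=torch.long)
)

assert torch.equal(sliced, selected)
```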
parent 9ffda216
@@ -247,8 +247,10 @@ class XLNetRelativeAttention(nn.Module):
         x = x[:, :, 1:, :]
         x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3]-1)
         # Note: the tensor-slice form was faster in my testing than torch.index_select
-        # x = torch.index_select(x, 2, torch.arange(klen, device=x.device, dtype=torch.long))
-        x = x[:, :, :, :klen]
+        # However, tracing doesn't like the nature of the slice, and if klen changes
+        # during the run then it'll fail, whereas index_select will be fine.
+        x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
+        # x = x[:, :, :, :klen]
         return x

@@ -290,7 +292,7 @@ class XLNetRelativeAttention(nn.Module):
         attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, v_head_h)
         if self.output_attentions:
-            return attn_vec, attn_prob
+            return attn_vec, torch.einsum('bnij->ijbn', attn_prob)
         return attn_vec