Unverified Commit 7f8b9091 authored by Mario Lezcano Casado, committed by GitHub

Compute the mask in-place, with fewer memory reads, and on CUDA in `XLNetLMHeadModel` (#23332)

While working on TorchInductor, I realised that a part of
`XLNetLMHeadModel` was being compiled to CPU code.

This PR should allow this operation to be fused with other CUDA operations
in `torch.compile`. It should also be faster in eager mode, as this
implementation has a lower memory footprint.
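
For illustration only (not part of this PR), here is a minimal sketch of wrapping the new construction in `torch.compile`; the function name, signature, and default device are assumptions made for the example:

```python
import torch


def causal_attn_mask(qlen: int, mlen: int, device: str = "cuda") -> torch.Tensor:
    # Single allocation on the target device, then filled in place
    # (the same_length=False case from the PR).
    mask = torch.ones(qlen, qlen + mlen, device=device)
    mask.triu_(mlen + 1)
    return mask


# Under torch.compile, Inductor should be able to generate device code for this
# and fuse it with neighbouring pointwise CUDA operations instead of falling
# back to CPU.
compiled_mask = torch.compile(causal_attn_mask)
```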

Even if in-place operations are not allowed in a non-grad context, I still
believe that doing `ones + tril` rather than `ones + tril + zeros + cat`
should be faster, simply due to the reduced number of memory reads/writes.

I tested that this code produces the same results for `0 <= qlen,mlen <
10` and `same_length in (True, False)`.
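
For reference, a minimal standalone sketch of such a check (not the author's actual test script), comparing the previous concatenation-based construction with the new in-place one; the helper names and the CPU device are illustrative assumptions:

```python
import torch


def mask_old(qlen, mlen, same_length):
    # Previous implementation: ones + triu + zeros + cat (+ tril when same_length).
    attn_mask = torch.ones([qlen, qlen])
    mask_up = torch.triu(attn_mask, diagonal=1)
    attn_mask_pad = torch.zeros([qlen, mlen])
    ret = torch.cat([attn_mask_pad, mask_up], dim=1)
    if same_length:
        mask_lo = torch.tril(attn_mask, diagonal=-1)
        ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
    return ret


def mask_new(qlen, mlen, same_length, device="cpu"):
    # New implementation: allocate once on the target device and fill it in place.
    mask = torch.ones(qlen, qlen + mlen, device=device)
    if same_length:
        mask_lo = mask[:, :qlen].tril(-1)
        mask.triu_(mlen + 1)
        mask[:, :qlen] += mask_lo
    else:
        mask.triu_(mlen + 1)
    return mask


for qlen in range(10):
    for mlen in range(10):
        for same_length in (True, False):
            assert torch.equal(
                mask_old(qlen, mlen, same_length),
                mask_new(qlen, mlen, same_length),
            ), (qlen, mlen, same_length)
print("old and new masks agree for 0 <= qlen, mlen < 10")
```

The new path allocates a single `(qlen, qlen + mlen)` tensor directly on the target device and fills it in place, whereas the old path materialises several intermediates and a concatenation before moving the result to the device.
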
parent 8c8744a9
@@ -976,16 +976,15 @@ class XLNetModel(XLNetPreTrainedModel):
                v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
         """
-        attn_mask = torch.ones([qlen, qlen])
-        mask_up = torch.triu(attn_mask, diagonal=1)
-        attn_mask_pad = torch.zeros([qlen, mlen])
-        ret = torch.cat([attn_mask_pad, mask_up], dim=1)
+        mask = torch.ones(qlen, qlen + mlen, device=self.device)
         if self.same_length:
-            mask_lo = torch.tril(attn_mask, diagonal=-1)
-            ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
-
-        ret = ret.to(self.device)
-        return ret
+            mask_lo = mask[:, :qlen].tril(-1)
+            mask.triu_(mlen + 1)
+            mask[:, :qlen] += mask_lo
+        else:
+            mask.triu_(mlen + 1)
+
+        return mask
 
     def cache_mem(self, curr_out, prev_mem):
         # cache hidden states into memory.
...