Compute the mask in-place, with fewer memory reads, and on CUDA in `XLNetLMHeadModel` (#23332)
While working on TorchInductor, I realised that part of `XLNetLMHeadModel` was being compiled to CPU code. This PR should allow `torch.compile` to fuse this operation with other CUDA operations. It should also be faster in eager mode, as this implementation has a lower memory footprint. Even if in-place operations are not allowed here (even in a non-grad context), I still believe that doing ones + tril rather than ones + tril + zeros + cat should be faster, simply due to the lower number of memory reads/writes. I tested that this code produces the same results for `0 <= qlen, mlen < 10` and `same_length in (True, False)`.
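For reference, a minimal sketch of the idea (not the exact diff from this PR): the standalone `create_mask` helper, its signature, and the `device` argument are illustrative, while `qlen`, `mlen`, and `same_length` follow the names used in `modeling_xlnet.py`.

```python
import torch


def create_mask(qlen, mlen, same_length=False, device=None):
    """Illustrative sketch: build the [qlen, qlen + mlen] attention mask with a
    single allocation and in-place ops, instead of ones + tril + zeros + cat."""
    # One allocation, directly on the target device (e.g. CUDA), so the mask
    # never starts life as a CPU tensor that has to be moved afterwards.
    mask = torch.ones(qlen, qlen + mlen, device=device)
    # Keep only strictly-future positions: entry (i, j) stays 1 when
    # j >= i + mlen + 1. This replaces the zeros padding block + cat.
    mask.triu_(diagonal=mlen + 1)
    if same_length:
        # Additionally mask the lower-triangular part of the current segment,
        # so every query attends to the same number of positions.
        mask[:, :qlen] += torch.tril(
            torch.ones(qlen, qlen, device=device), diagonal=-1
        )
    return mask
```

The in-place `triu_` avoids allocating the intermediate zeros block and the `cat`, which is where the saved memory reads/writes would come from.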