Optimize causal mask using torch.where (#2715)
* Optimize causal mask using torch.where Instead of multiplying by 1.0 float mask, use torch.where with a bool mask for increased performance. * Maintain compatiblity with torch 1.0.0 - thanks for PR feedback * Fix typo * reformat line for CI
Showing
Please register or sign in to comment