[Feature] Optimize the PyTorch CUDA implementation for Criss Cross Attention (#1143)
* optimize criss cross attention
* optimize criss cross attention
* optimize criss cross attention
* fix lint
* fix ci, remove useless variable
* better ca_forward_kernel
Co-authored-by:
wondervictor <victorchanchina@gmail.com>
Showing
Please register or sign in to comment