Commit 9a463332 (authored by Mauro Bisson)
    Optimized the BWD kernel, applying the same changes made to the FWD kernel in commit 8cb399ee:
    * Replaced PyTorch's slow permutation.
    * Split the kernel into general and specialized versions (the specialized path handles num_channel <= 8192).
    * Enabled float4-based vectorized memory access when possible.
    * Added runtime dispatch logic for kernel specialization (see the sketch below).
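    The snippet below is a minimal, hypothetical sketch of that pattern, not the actual kernels from this repository: a general scalar kernel, a float4-vectorized variant, and a host-side launcher that dispatches to the vectorized path only when the element count and pointer alignment allow 16-byte accesses. All identifiers (scale_grad_general, scale_grad_vec4, launch_scale_grad) are placeholders, and the real specialization on num_channel <= 8192 involves more than what is shown here.

    #include <cuda_runtime.h>
    #include <cstdint>

    // General fallback: scalar accesses, valid for any size and alignment.
    __global__ void scale_grad_general(const float* __restrict__ grad_in,
                                       float* __restrict__ grad_out,
                                       float scale, long long n) {
        long long i = blockIdx.x * (long long)blockDim.x + threadIdx.x;
        if (i < n) grad_out[i] = grad_in[i] * scale;
    }

    // Specialized variant: each thread moves 16 bytes per access via float4;
    // requires n % 4 == 0 and 16-byte-aligned pointers.
    __global__ void scale_grad_vec4(const float4* __restrict__ grad_in,
                                    float4* __restrict__ grad_out,
                                    float scale, long long n4) {
        long long i = blockIdx.x * (long long)blockDim.x + threadIdx.x;
        if (i < n4) {
            float4 v = grad_in[i];
            v.x *= scale; v.y *= scale; v.z *= scale; v.w *= scale;
            grad_out[i] = v;
        }
    }

    // Host-side runtime dispatch: pick the float4 path when possible,
    // otherwise fall back to the general kernel.
    void launch_scale_grad(const float* grad_in, float* grad_out,
                           float scale, long long n, cudaStream_t stream) {
        const int block = 256;
        const bool aligned16 =
            reinterpret_cast<std::uintptr_t>(grad_in) % 16 == 0 &&
            reinterpret_cast<std::uintptr_t>(grad_out) % 16 == 0;
        if (aligned16 && n % 4 == 0) {
            const long long n4 = n / 4;
            const unsigned int grid = (unsigned int)((n4 + block - 1) / block);
            scale_grad_vec4<<<grid, block, 0, stream>>>(
                reinterpret_cast<const float4*>(grad_in),
                reinterpret_cast<float4*>(grad_out), scale, n4);
        } else {
            const unsigned int grid = (unsigned int)((n + block - 1) / block);
            scale_grad_general<<<grid, block, 0, stream>>>(grad_in, grad_out,
                                                           scale, n);
        }
    }

    The same host-side dispatch shape extends naturally to choosing between a general kernel and a num_channel-specialized one at runtime.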
    
    Aligned attention_fwd_cuda.cu with attention_bwd_cuda.cu in terms of naming conventions and kernel parameters.
    
    Extracted shared host/device functions and declarations into a separate module (sketched after the list):
    * attention_utils.cuh
    * attention_utils.cu
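    As a rough illustration only (the actual contents of the module are not shown in this commit), such a split typically keeps shared device helpers inline in the header and host-side helpers in the .cu file. The names attn_ceil_div and attn_fma below are hypothetical placeholders:

    // attention_utils.cuh
    #pragma once
    #include <cuda_runtime.h>

    // Host-side launch-configuration helper shared by the FWD and BWD launchers.
    int attn_ceil_div(int a, int b);

    // Device helper shared by the FWD and BWD kernels; defined inline in the
    // header so each translation unit gets its own copy (no -rdc needed).
    __device__ __forceinline__ float attn_fma(float a, float b, float c) {
        return fmaf(a, b, c);
    }

    // attention_utils.cu
    #include "attention_utils.cuh"

    int attn_ceil_div(int a, int b) { return (a + b - 1) / b; }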