Optimized FWD kernel: custom permutations, gmem accesses reduction, vectorized access
* Replaced PyTorch's slow permutation ops with custom kernels, significantly improving performance (especially on GB200). * Split kernel into general and specialized versions for num_channel <= 16384, significantly reducing memory accesses. * Enabled float4-based vectorized memory access when pointer alignment and channel size allow, improving throughput. * Added runtime dispatch logic for kernel specialization.
Showing
This diff is collapsed.