    e1d82bf3 · Yu Cheng authored
    [Dev] Adjust computation logic to avoid precision loss when casting acc_s from float to float16 (#141)
    
    - Remove redundant `acc_s_0` fragment in flash attention kernel
    - Simplify memory copy and reduction operations
    - Reorder memory copy and scaling steps for improved performance
    - Add Hopper-specific synchronization method in CUDA reduce template
    - Update reduce operation to use architecture-specific synchronization
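    The commit message does not include the kernel diff. The sketch below is a hypothetical, pure-Python illustration of why the order of scaling and casting matters for a score fragment like `acc_s`. It is not the kernel code from `example_mla_decode.py`. It rounds values to fp16 with the `struct` module's half-precision `'e'` format, which stands in for a float-to-float16 cast, and compares two orderings of a softmax-style `exp(scale * (s - max))` step. The function names, sample scores, and scale are assumptions made up for illustration.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value.

    struct's 'e' format packs to binary16, so a pack/unpack round trip
    applies fp16 rounding. Python floats are float64 and stand in here
    for the float32 accumulator.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


def probs_cast_then_scale(scores, scale):
    # Hypothetical lossy order: cast the raw scores to fp16 first,
    # then subtract the row max and exponentiate.
    s16 = [to_fp16(s) for s in scores]
    m = max(s16)
    return [math.exp((s - m) * scale) for s in s16]


def probs_scale_then_cast(scores, scale):
    # Hypothetical order that keeps precision: subtract the max, scale and
    # exponentiate in full precision, and cast only the final values to fp16.
    m = max(scores)
    return [to_fp16(math.exp((s - m) * scale)) for s in scores]


def reference(scores, scale):
    # Full-precision result with no fp16 rounding at all.
    m = max(scores)
    return [math.exp((s - m) * scale) for s in scores]


def max_abs_err(a, b):
    return max(abs(x - y) for x, y in zip(a, b))


# Hypothetical large logits: near 1000, fp16 can only represent multiples of 0.5.
scores = [1000.3, 1000.1, 999.7]
scale = 1.0
ref = reference(scores, scale)
err_bad = max_abs_err(probs_cast_then_scale(scores, scale), ref)
err_good = max_abs_err(probs_scale_then_cast(scores, scale), ref)
print(f"cast-then-scale max error: {err_bad:.4f}")
print(f"scale-then-cast max error: {err_good:.6f}")
```

    With scores this large, casting first snaps them to the 0.5 grid (1000.3 becomes 1000.5, 999.7 becomes 999.5), so the score differences fed to the exponential are already wrong. Casting only the exponentiated values, which lie in (0, 1], keeps the error within ordinary fp16 rounding.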
example_mla_decode.py 14.2 KB