[Example] Add support for `bfloat16` and user-defined `sm_scale` in attention sink examples (#924)
* revert split+sum template for MHA backward
* lint
* Update example_mha_bwd.py
* Update example_mha_bwd_wgmma_pipelined.py
* Refactor attention sink examples to support bf16 and user-defined softmax scale (a reference sketch of the formulation follows this list)
* fix typos
* Add compile flags for fast-math optimizations and enable BF16 support in both GQA and MHA backward implementations
* Update backward configuration for GQA and MHA examples to align with flash attention
* Refactor GQA backward implementation to improve atomic add performance (the accumulation it implements is sketched after this list)
* Allow for slightly larger numerical error for bf16 (see the tolerance sketch after this list)
* upd readme to show bf16 benchmark results
* lint
* fix ci and lint
* fix comments and lint
* refactor atomic add
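
For context, below is a minimal PyTorch sketch of one common attention-sink formulation: a per-head sink logit that joins the softmax normalization but contributes no value vector, with the softmax scale exposed as a user-defined `sm_scale`. The function and parameter names (`attention_sink_ref`, `sinks`, `sm_scale`) are illustrative; this is not the TileLang kernel code changed by this PR.

```python
import math
import torch
import torch.nn.functional as F

def attention_sink_ref(q, k, v, sinks, sm_scale=None, is_causal=True):
    """q, k, v: [batch, heads, seq, dim]; sinks: [heads] (one logit per head)."""
    b, h, s, d = q.shape
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(d)  # conventional default scale

    # Accumulate in fp32 regardless of bf16/fp16 inputs.
    scores = torch.einsum("bhqd,bhkd->bhqk", q.float(), k.float()) * sm_scale
    if is_causal:
        mask = torch.triu(torch.ones(s, s, dtype=torch.bool, device=q.device), 1)
        scores = scores.masked_fill(mask, float("-inf"))

    # The sink acts as one extra logit per row: it absorbs probability
    # mass in the softmax but is dropped before multiplying by V.
    sink_col = sinks.float().view(1, h, 1, 1).expand(b, h, s, 1)
    probs = F.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)[..., :-1]

    out = torch.einsum("bhqk,bhkd->bhqd", probs, v.float())
    return out.to(q.dtype)
```

With this parameterization, running the examples in bf16 amounts to creating `q`/`k`/`v`/`sinks` in `torch.bfloat16` and optionally passing a non-default `sm_scale`.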
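Because bf16 carries fewer mantissa bits than fp16, the correctness checks allow a looser bound for bf16 outputs. A hedged sketch of that idea (thresholds are illustrative, not the exact values used in the examples):

```python
import torch

def assert_similar(out, ref, dtype):
    # bf16 has fewer mantissa bits than fp16, so comparison against an
    # fp32 reference uses a wider tolerance (values are illustrative).
    rtol, atol = (2e-2, 2e-2) if dtype == torch.bfloat16 else (1e-2, 1e-2)
    torch.testing.assert_close(out.float(), ref.float(), rtol=rtol, atol=atol)
```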
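On the GQA backward refactor: several query heads share one KV head, so the backward pass has to sum the dK/dV contributions from every query head in a group into the shared KV head, and in the kernels that reduction is what the atomic adds perform, since different thread blocks produce the partial gradients concurrently. The sketch below shows only the equivalent math, assuming query heads are grouped contiguously (query head `i` maps to KV head `i // group_size`); names and shapes are illustrative.

```python
import torch

def reduce_dkv_over_group(dk_per_qhead, dv_per_qhead, group_size):
    """dk_per_qhead, dv_per_qhead: [batch, q_heads, seq, dim] partial
    gradients produced per query head; returns [batch, kv_heads, seq, dim]."""
    b, hq, s, d = dk_per_qhead.shape
    hk = hq // group_size
    # Summing over the group dimension is the serial equivalent of the
    # kernel's atomic adds into the shared KV-head gradient buffers.
    dk = dk_per_qhead.view(b, hk, group_size, s, d).sum(dim=2)
    dv = dv_per_qhead.view(b, hk, group_size, s, d).sum(dim=2)
    return dk, dv
```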
---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>