"...composable_kernel.git" did not exist on "d2b1ed1bc592ed39aee2d8751c2a6d9ad911de03"
[JAX] Prepare cross flash attention (#525)
* Add rng_state output for cross fused attention Signed-off-by:Reese Wang <rewang@nvidia.com> * Add rng_state and output for the flash attention backward Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add bias for the jax cross attn API Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a minor bug Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add bias in the backward for the arbitrary fused attn backend Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
Showing
Please register or sign in to comment