-
zlsh80826 authored
* Add rng_state output for cross fused attention Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add rng_state and output for the flash attention backward Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add bias for the jax cross attn API Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a minor bug Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add bias in the backward for the arbitrary fused attn backend Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
4d444db1