[JAX] Added prepare phase for the FusedAttnForwardFFI (#1313)
* added prepare phase for the FusedAttnForwardFFI
* enabled FusedAttnForwardFFI by default
* moved prepare phase into pybind
---------
Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment