[JAX] Expose THD format to the flax module (#1480)
* Expose THD to flex MHA module Signed-off-by:Reese Wang <rewang@nvidia.com> * Enhance docs Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment