[JAX] Add self_attn_mask_type and replace attn_type (#273)
* Add self_attn_mask_type and replace attn_type Signed-off-by:Reese Wang <rewang@nvidia.com> * Refine the keyword style for the better readability Signed-off-by:
Reese Wang <rewang@nvidia.com> * Replace attn_type with attn_mask_type in praxis transformer Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix typos Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Showing
Please register or sign in to comment