[JAX] Add fast path for causal masking with segment IDs. (#1601)
Add fast path for causal masking with segment IDs.
Signed-off-by:
Michael Goldfarb <mgoldfarb@nvidia.com>
Showing
Please register or sign in to comment
Add fast path for causal masking with segment IDs.
Signed-off-by:
Michael Goldfarb <mgoldfarb@nvidia.com>