[JAX] Use relative idx to ScaledUpperTriangMaskedSoftmaxFwdPrimitive (#523)
Use relative idx to ScaledUpperTriangMaskedSoftmaxFwdPrimitive.abstract to support batching.
Signed-off-by:
Ming Huang <mingh@nvidia.com>
Showing
Please register or sign in to comment