Commit 0fc402fb (unverified), authored by Ming-Xu Huang, committed by GitHub

[JAX] Use relative idx to ScaledUpperTriangMaskedSoftmaxFwdPrimitive (#523)



Use relative (negative) indices in ScaledUpperTriangMaskedSoftmaxFwdPrimitive.abstract to support batching.
Signed-off-by: Ming Huang <mingh@nvidia.com>
parent d76118d9
@@ -1482,8 +1482,8 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
 """
 te_scaled_upper_triang_masked_softmax_forward abstract
 """
-q_seqlen = logits_aval.shape[2]
-k_seqlen = logits_aval.shape[3]
+q_seqlen = logits_aval.shape[-2]
+k_seqlen = logits_aval.shape[-1]
 assert q_seqlen == k_seqlen
 return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)
......