"git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "dafd924c1f165b4478a9d7a3c915d2ecc2e148e2"
[JAX] Regression tests for custom ops with jax.experimental.custom_partitioning (#471)
[JAX] Regression tests for custom ops sharding with both xmap and custom_partitioning.
Coverage:
- layernorm: fwd/grad, zero_centered_gamma, DP, TP_COL, DP_TP_COL
- rmsnorm: fwd/grad, DP, TP_COL, DP_TP_COL
- softmax: fwd/grad, SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED, DP, TP_COL, TP_ROW, DP_TP_COL, DP_TP_ROW
- self_fused_attn: fwd/grad, NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
- cross_fused_attn: fwd/grad, NO_BIAS, NO_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
Signed-off-by:
Alp Dener <adener@nvidia.com>
Showing
Please register or sign in to comment