    [JAX] Regression tests for custom ops with jax.experimental.custom_partitioning (#471) · d20ba9fb
    Alp Dener authored
    
    
    [JAX] Regression tests for custom-op sharding with both xmap and custom_partitioning.
    
    Coverage:
    - layernorm: fwd/grad, zero_centered_gamma, DP, TP_COL, DP_TP_COL
    - rmsnorm: fwd/grad, DP, TP_COL, DP_TP_COL
    - softmax: fwd/grad, SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED, DP, TP_COL, TP_ROW, DP_TP_COL, DP_TP_ROW
    - self_fused_attn: fwd/grad, NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
    - cross_fused_attn: fwd/grad, NO_BIAS, NO_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
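The cases above follow one pattern: run an op on sharded inputs and check both the forward output and the gradients against an unsharded reference. The sketch below illustrates that pattern for a data-parallel (DP) layernorm case. It is not the test code from this commit. The `layernorm` function is a hypothetical plain-JAX stand-in for the fused custom op, its `zero_centered_gamma` handling (scale by `gamma + 1`) is an assumption, and the mesh axis name `"dp"` and the helper `check_dp_layernorm` are illustrative. It uses only `jax.jit` with `NamedSharding`, not xmap or custom_partitioning, and it runs on a single device, where the mesh has one entry.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def layernorm(x, gamma, beta, eps=1e-6, zero_centered_gamma=False):
    # Hypothetical plain-JAX reference layernorm over the last axis
    # (stand-in for the fused custom op, not its implementation).
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    # Assumption: zero-centered gamma stores the scale as an offset from 1.
    scale = gamma + 1.0 if zero_centered_gamma else gamma
    return (x - mean) * jax.lax.rsqrt(var + eps) * scale + beta


def loss(x, gamma, beta):
    # Scalar loss so the gradient path (the "grad" half of fwd/grad) is exercised.
    return jnp.mean(layernorm(x, gamma, beta))


def check_dp_layernorm(batch=8, hidden=16):
    # 1-D mesh over all visible devices; batch must be divisible by its size.
    mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))
    x = jax.random.normal(jax.random.PRNGKey(0), (batch, hidden), dtype=jnp.float32)
    gamma = jnp.ones((hidden,), jnp.float32)
    beta = jnp.zeros((hidden,), jnp.float32)

    # Unsharded reference results.
    ref_out = layernorm(x, gamma, beta)
    ref_grads = jax.grad(loss, argnums=(0, 1, 2))(x, gamma, beta)

    # DP: shard the batch axis, replicate the parameters.
    x_dp = jax.device_put(x, NamedSharding(mesh, P("dp", None)))
    replicated = NamedSharding(mesh, P())
    gamma_r = jax.device_put(gamma, replicated)
    beta_r = jax.device_put(beta, replicated)

    out = jax.jit(layernorm)(x_dp, gamma_r, beta_r)
    grads = jax.jit(jax.grad(loss, argnums=(0, 1, 2)))(x_dp, gamma_r, beta_r)

    # Sharded forward and backward results must match the reference.
    np.testing.assert_allclose(out, ref_out, rtol=1e-5, atol=1e-5)
    for g, ref_g in zip(grads, ref_grads):
        np.testing.assert_allclose(g, ref_g, rtol=1e-5, atol=1e-5)
    return True
```

In a fuller matrix, the other sharding types (for example TP_COL or DP_TP_COL) would presumably swap in a different mesh shape and `PartitionSpec` while reusing the same reference comparison.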
    Signed-off-by: Alp Dener <adener@nvidia.com>
test.sh 239 Bytes