1. 03 Nov, 2023 1 commit
    • Alp Dener's avatar
      [JAX] Regression tests for custom ops with jax.experimental.custom_partitioning (#471) · d20ba9fb
      Alp Dener authored
      
      
      [JAX] Regression tests for custom ops sharding with both xmap and custom_partitioning.
      
      Coverage:
      - layernorm: fwd/grad, zero_centered_gamma, DP, TP_COL, DP_TP_COL
      - rmsnorm: fwd/grad, DP, TP_COL, DP_TP_COL
      - softmax: fwd/grad, SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED, DP, TP_COL, TP_ROW, DP_TP_COL, DP_TP_ROW
      - self_fused_attn: fwd/grad, NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
      - cross_fused_attn: fwd/grad, NO_BIAS, NO_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
      Signed-off-by: default avatarAlp Dener <adener@nvidia.com>
      d20ba9fb
  2. 31 Oct, 2023 1 commit
  3. 26 Oct, 2023 1 commit
  4. 24 Oct, 2023 3 commits
  5. 23 Oct, 2023 2 commits
  6. 20 Oct, 2023 4 commits
  7. 19 Oct, 2023 1 commit
  8. 17 Oct, 2023 2 commits
  9. 13 Oct, 2023 2 commits
  10. 12 Oct, 2023 2 commits
  11. 11 Oct, 2023 3 commits
  12. 10 Oct, 2023 2 commits
  13. 09 Oct, 2023 3 commits
  14. 06 Oct, 2023 2 commits
  15. 05 Oct, 2023 1 commit
  16. 04 Oct, 2023 2 commits
  17. 03 Oct, 2023 2 commits
  18. 02 Oct, 2023 1 commit
  19. 01 Oct, 2023 1 commit
  20. 27 Sep, 2023 4 commits