- 18 Dec, 2024 1 commit
-
-
Phuong Nguyen authored
* Move test distributed encoder to L0 distributed test suit --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
Reese Wang <rewang@nvidia.com>
-
- 20 Feb, 2024 1 commit
-
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 01 Feb, 2024 1 commit
-
-
zlsh80826 authored
* Fix unfused GQA perf Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove WAR for Check failed: reduction_kind.has_value() Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 16 Jan, 2024 1 commit
-
-
zlsh80826 authored
* Support num_gqa_groups arguments Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add GQA support on the JAX bridge code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the kv stride of the arbitrary backend Signed-off-by:
Reese Wang <rewang@nvidia.com> * Complete rewrite fused attention tests and add GQA coverage Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support unfused GQA Signed-off-by:
Reese Wang <rewang@nvidia.com> * Calculate seqlen before the primitive for the better perf Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add GQA layer tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Apply code style checks for te_jax Signed-off-by:
Reese Wang <rewang@nvidia.com> * Apply code style checks for tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add num_gqa_groups doc Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine the qkv_type Signed-off-by:
Reese Wang <rewang@nvidia.com> * Correct the variable naming Signed-off-by:
Reese Wang <rewang@nvidia.com> * Handle Max512 CAUSAL Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add WAR for the latest jax image Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 03 Jan, 2024 1 commit
-
-
Przemyslaw Tredak authored
Signed-off-by:Przemek Tredak <ptredak@nvidia.com>
-
- 20 Nov, 2023 1 commit
-
-
zlsh80826 authored
* Remove assertion for NO_MASK Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix JAX distributed unit tests name Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 03 Nov, 2023 1 commit
-
-
Alp Dener authored
[JAX] Regression tests for custom ops sharding with both xmap and custom_partitioning. Coverage: - layernorm: fwd/grad, zero_centered_gamma, DP, TP_COL, DP_TP_COL - rmsnorm: fwd/grad, DP, TP_COL, DP_TP_COL - softmax: fwd/grad, SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED, DP, TP_COL, TP_ROW, DP_TP_COL, DP_TP_ROW - self_fused_attn: fwd/grad, NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL - cross_fused_attn: fwd/grad, NO_BIAS, NO_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL Signed-off-by:Alp Dener <adener@nvidia.com>
-