- 12 May, 2025 1 commit
-
-
jberchtold-nvidia authored
This reverts commit 5bee81e2 . Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 05 May, 2025 2 commits
-
-
Phuong Nguyen authored
* removes unneccessary reshapes for FP8 GEMM * use nn.jax.scaled_matmul Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
jberchtold-nvidia authored
* Enforce input sharding of norm primitive does not shard hidden dim Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix partitioning issue in dact primitive causing NaN and add better shape checks before calling TE API Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Move dact shape assertion from cpp to python Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 01 May, 2025 1 commit
-
-
Phuong Nguyen authored
* exclude GroupedGemm APIs Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 30 Apr, 2025 1 commit
-
-
jberchtold-nvidia authored
Fix distributed layernorm test failure Signed-off-by:Jeremy Berchtold <jberchtold@nvidia.com>
-
- 29 Apr, 2025 1 commit
-
-
jberchtold-nvidia authored
* Update test_helper.py and add QuantizeConfig class for CurrentScaling Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * WIP distributed current scaling Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Distributed Current Scaling (debugging). Distributed implementation with replicated scale_inv works for layernorm_mlp but feels like a hack Has different per-device scale_inv values, but jax.debug.print only shows one of them. Since we're telling JAX/XLA that this scale is replicated, I think it assumes all the values are equal. However, it doesn't actually check this, so it seems we are able to get away with per-device scales for current scaling but I am not sure how stable this will be and may randomly fail if us or the user changes partitioning at all or if XLA decides to actually act on the assumption that all these scale_invs are the same. Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Implement distributed current scaling by computing a global amax and scale before quantization Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Add encoder and mnist tests for current scaling Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Add primitive prefix to shardy unique_vars to prevent factor conflicts when performing unfused primitives for current scaling Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Remove scale_shape primitive arg that is no longer used Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Format Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix expected result on multiprocessing encoder test Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Lint fix Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Update multiprocessing current scaling tolerances Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Uncomment test case that was disabled for testing Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Remove commented out debug line Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 22 Apr, 2025 1 commit
-
-
jberchtold-nvidia authored
* [JAX-Q] Single GPU current scaling for JAX Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix scale check dtype for MXFP8 scales affecting tests using assert_bitwise_scaled_tensors Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Address comments Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Remove cast to fp32 for norm primitives now that zero-centered gamma dtype issue is fixed Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix lint issue Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Remove unnecessary cast to fp32 Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Lint Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 21 Apr, 2025 1 commit
-
-
jberchtold-nvidia authored
Check CuDNN version and apply unfused norm if below a version with the fix Signed-off-by:Jeremy Berchtold <jberchtold@nvidia.com>
-
- 17 Apr, 2025 1 commit
-
-
jberchtold-nvidia authored
* Add a flag to support computing zero-centered gamma in weight dtype or compute dtype for CuDNN Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Address comments Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 16 Apr, 2025 1 commit
-
-
Kshitij Lakhani authored
* Add test cases for full coverage in jax/test_layer.py - causal and window size None - causal and window size default (-1,1) - no_mask and window size default (-1,1) - no_mask and window size default (2,2) - padding and window size None - padding_causal and window_size (2,2) Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Correct the condition where padding_causal_mask was being mapped to scaled upper triangle Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Fix Issue #1524 Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Add a runner and test cases for jax.flax.module.Softmax class for fwd pass only Segregate runner classes for Softmax module and softmax primitives Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Simplify logic when picking softmax primitives and softmax jax framework calls Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Simplify the logic for performing jax based softmax Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Code clean up Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support table for mask, SWA and Softmax type. Code linting Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Explicit SWA conditons in comments. Fix Typo Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolve typo to remove None in SWA comments section Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 14 Apr, 2025 2 commits
-
-
Johannes Reifferscheid authored
* Add experimental Shardy support. Production use is not yet recommended. --------- Signed-off-by:Johannes Reifferscheid <jreiffers@nvidia.com>
-
Hua Huang authored
* New GroupedGemmPrimitive using variadic args * Remove squeeze() to reduce D2D memcpy * Revert to the list append fashion to simplify code --------- Signed-off-by:
Hua Huang <huah@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 09 Apr, 2025 1 commit
-
-
Phuong Nguyen authored
* scaling enum abstract * rm NVTE_ from ScalingMode names * rework scaling mode enum in grouped gemm * fix norm sharding --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 07 Apr, 2025 1 commit
-
-
Phuong Nguyen authored
* rm no scaling enum Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * update jax enum Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 04 Apr, 2025 1 commit
-
-
Phuong Nguyen authored
* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout * add fatten_axis option * added gated act to test encoder * sharding constraint fixes * fix padding when flattening first dim needs to be padded * update test sizes so that padding is tested * rm output sharding as it can be done in the flax module * sharding scale_inv for mxfp8 --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 01 Apr, 2025 1 commit
-
-
Phuong Nguyen authored
* refactor + mxfp8 * added grouped gemm * rename linear to dense * added cublas init phase for groupedGemm * relax the tol of test encoder multiprocessing mxfp8 by 0.001 Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
Hua Huang <huah@nvidia.com> Co-authored-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 18 Mar, 2025 1 commit
-
-
Michael Goldfarb authored
* Fix softmax shape for THD format. Signed-off-by:Michael Goldfarb <mgoldfarb@nvidia.com>
-
- 13 Mar, 2025 1 commit
-
-
Reese Wang authored
Make ffi compatible with jax 0.4 Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 12 Mar, 2025 1 commit
-
-
Reese Wang authored
Remove xla_ignore_channel_id check and ignore Scan loop warning in unit test Signed-off-by:Reese Wang <rewang@nvidia.com>
-
- 05 Mar, 2025 1 commit
-
-
Kirthi Shankar Sivamani authored
* Fix wheel install after src install Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix JAX imports Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * switch order of dirs for finding so Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Use existing dir src build Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix lint Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 03 Mar, 2025 1 commit
-
-
Reese Wang authored
* Support THD + ring attention for self attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Consolidate reorder strategy Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix dataclass frozen issue Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove redundant code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use AttnBiasType, AttnMaskType, QKVLayout in cpp_extension Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine P2P helper check_supported Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add segment_ids/pos check Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fixup Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add dual chunk swap example Signed-off-by:
Reese Wang <rewang@nvidia.com> * Align different reorder code structure Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 14 Feb, 2025 2 commits
-
-
Reese Wang authored
* Expose THD to flex MHA module Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance docs Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
Phuong Nguyen authored
* fixes L1 test * fix test_multigpu_encoder * fixes for other multi-encoder tests * jax.extend.ffi to jax.ffi * initialization with float32 * add init_dtype as an optional arg to all modules * update use_scan query from xla flags * relax threshold for test_encoder fp8 * relax the tols --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 11 Feb, 2025 1 commit
-
-
Phuong Nguyen authored
* flax module to init params with given dtype Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * all tests passed Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * remove unneccessary reshape for kernel Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * remove casting output of dot Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * clean up Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 24 Jan, 2025 1 commit
-
-
Reese Wang authored
* POC for segment_ids/segment_pos Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change segment_pos position Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use RemainingArgs to solve number of parameters mismatches Signed-off-by:
Reese Wang <rewang@nvidia.com> * Test mask_descriptor for accomendating different mask representations Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix bugs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use descriptor in bwd Signed-off-by:
Reese Wang <rewang@nvidia.com> * Primitives only accepts pure jnp array Signed-off-by:
Reese Wang <rewang@nvidia.com> * segment_ids/pos support POC Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move seqlens/offsets generation to mask descriptor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename MaskDescriptor to SequenceDescriptor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Generalize get_seqlens_and_offsets Signed-off-by:
Reese Wang <rewang@nvidia.com> * Utilize sequence desc on FA bwd Signed-off-by:
Reese Wang <rewang@nvidia.com> * Migrate to new API Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add docstrings Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove small inputs and test different input format Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix seed shardings Signed-off-by:
Reese Wang <rewang@nvidia.com> * Optimize sequence converting overhead Signed-off-by:
Reese Wang <rewang@nvidia.com> * Optimize seq_offsets calculation Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix up Signed-off-by:
Reese Wang <rewang@nvidia.com> * fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix conflicts Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove reduntant line Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 08 Jan, 2025 1 commit
-
-
Michael Goldfarb authored
Correct fused attention output after each step to reduce intermediate memory use. Signed-off-by:Michael Goldfarb <mgoldfarb@nvidia.com>
-
- 02 Jan, 2025 1 commit
-
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 17 Dec, 2024 1 commit
-
-
Reese Wang authored
* Add util functions to attn_mask_type Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add util functions to qkv_layout Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix THD cross reference code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove explicit segment_pad, encoding it to segment_ids Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add jax.jit, replace _token with segment_ids, rename bias shape enum Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add comment for make_mask Signed-off-by:
Reese Wang <rewang@nvidia.com> * Clean code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add doc strings for the added functions Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove cache for fa deterministic which causes UT failed Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename fixture to avoid conflict Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 14 Dec, 2024 1 commit
-
-
Phuong Nguyen authored
* softmax custom calls with correct encapsulates * rm jax deprecated features --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 12 Dec, 2024 1 commit
-
-
Phuong Nguyen authored
* fix ctx.aval_out indexing for workspace * add cudnn init to prepare phase of norm custom calls * add thread_local for norm registry instance --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 06 Dec, 2024 1 commit
-
-
Phuong Nguyen authored
* cuDNN normalization integration * TE Norm refactor * TE Norm APIs changes. --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 12 Nov, 2024 1 commit
-
-
Hua Huang authored
* FFI for all softmax functions Signed-off-by:
Hua Huang <huah@nvidia.com> * FFI for FusedAttnBackward and Dequantize FusedAttnBackward passed all testes in test_fused_attn.py. Dequantize is not used currently; finish it for completeness. Signed-off-by:
Hua Huang <huah@nvidia.com> * Fix FusedAttnBackward FFI pybind & simplify Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert changes to tests/jax/test_fused_attn.py Signed-off-by:
Hua Huang <huah@nvidia.com> --------- Signed-off-by:
Hua Huang <huah@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
-
- 11 Nov, 2024 1 commit
-
-
Ming-Xu Huang authored
* Implement ring attention primative for Jax. Signed-off-by:
Michael Goldfarb <mgoldfarb@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Michael Goldfarb <mgoldfarb@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> Co-authored-by:
Michael Goldfarb <mgoldfarb@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 07 Nov, 2024 1 commit
-
-
Phuong Nguyen authored
* added prepare phase for the FusedAttnForwardFFI * enabled FusedAttnForwardFFI by default * moved prepare phase into pybind --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 06 Nov, 2024 1 commit
-
-
Hua Huang authored
* FFI for some transpose & activation functions Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove comments in transformer_engine/jax/csrc/extensions/activation.cpp Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Signed-off-by:
Hua Huang <huangh1994@outlook.com> --------- Signed-off-by:
Hua Huang <huah@nvidia.com> Signed-off-by:
Hua Huang <huangh1994@outlook.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
-
- 04 Nov, 2024 1 commit
-
-
Hua Huang authored
* Add LayerNormForwardFFI(); add FFI calls in Python Signed-off-by:
Hua Huang <huah@nvidia.com> * Add FFI for RMS norm, all tests passed Signed-off-by:
Hua Huang <huah@nvidia.com> * Simplify layer & RMS norm FFI calls Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Simplify tensor size calculations Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by:
Hua Huang <huah@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
-
- 01 Nov, 2024 1 commit
-
-
Phuong Nguyen authored
rm default value for NVTE_JAX_FUSED_ATTN_WITH_FFI Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 31 Oct, 2024 2 commits
-
-
Phuong Nguyen authored
* disable fused attn with ffi --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
Phuong Nguyen authored
* lowering a dict of attrs * improve err message with line and func info * implement a product() for ffi dimensions --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 25 Oct, 2024 1 commit
-
-
Charlene Yang authored
* WIP: add max_t support for THD Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP: save tensors for debug and point to new FE Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix stats in bwd Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix stats in fwd Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add docstring for DPA Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add docstring Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: first try on adding max_b and max_t Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" This reverts commit c3d522e9f5aef3c8ddfec5bf6ff24c3db97bb059. Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "WIP: first try on adding max_b and max_t" This reverts commit 3bc01ebaf2aa846fd16634e2d33b0d0f5803a076. Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update docstring and fix max_seqlen logic for thd Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert two lines of change in docstring Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: add get_max_b/t Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix max_seqlen code and docstring Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * sucess: add max_b/max_t Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove debug code Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change max_b/max_t buckets Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix b vs orig_b Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix b vs orig_b with 0 fill Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE for T3HD/TH3D Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add max_b to conversion kernels Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix changes after last merge Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add Jax support for max_t Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update FE to 1.8.0-rc Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to 1.8.0 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * code review/formating fixes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Stats shape for <9.6 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * return nullptr for offset_stats when cudnn < 9.6 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more version control Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-