add building workflow for TE/Jax (#53)
* add building workflow for jax modules Signed-off-by:Ryan Jeng <rjeng@nvidia.com> * replace bit_cast with reinterpret_cast Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add nvtx to cmake check list Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor rmsnorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm_bwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * set pytorch as default in setup.py Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename extension from *.cc to *.cpp cpplint cannot recognize *.cc file, so rename the extension Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor style, to align TE/PyTorch Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add pybinding, unittest and qa Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix license Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * disable c-extension-no-member and no-name-in-module Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dataclass avoid pylint error Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/__init__.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py fix typo Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * add building workflow for jax modules Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace bit_cast with reinterpret_cast Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add nvtx to cmake check list Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor rmsnorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm_bwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * set pytorch as default in setup.py Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename extension from *.cc to *.cpp cpplint cannot recognize *.cc file, so rename the extension Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor style, to align TE/PyTorch Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add pybinding, unittest and qa Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix license Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * disable c-extension-no-member and no-name-in-module Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dataclass avoid pylint error Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/__init__.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py fix typo Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * fix conflict due to PR62 Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix c-extension-no-member and no-name-in-module 1. add transformer_engine_jax into extension-pkg-whitelist 2. convert pylintrc from CRLF to LF format Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update setup.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove pylint:disable and refactor import order Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Showing
qa/L0_jax_unittest/test.sh
0 → 100644
tests/jax/test_helper.py
0 → 100644
tests/jax/test_sharding.py
0 → 100644
tests/jax/utils.py
0 → 100644
Please register or sign in to comment