Unverified Commit a3ec6a54 authored by Jeng Bai-Cheng's avatar Jeng Bai-Cheng Committed by GitHub
Browse files

add building workflow for TE/Jax (#53)



* add building workflow for jax modules
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* replace bit_cast with reinterpret_cast
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* add nvtx to cmake check list
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* refactor layernorm fwd
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* refactor rmsnorm fwd
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* refactor layernorm_bwd
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* set pytorch as default in setup.py
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* rename extension from *.cc to *.cpp

cpplint cannot recognize *.cc file, so rename the extension
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* refactor style, to align TE/PyTorch
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* add pybinding, unittest and qa
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* fix license
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* disable c-extension-no-member and no-name-in-module
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* add dataclass avoid pylint error
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* Update transformer_engine/__init__.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarJeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update tests/jax/test_custom_call_shape.py

fix typo
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarJeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update tests/jax/test_custom_call_shape.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarJeng Bai-Cheng <jeng1220@users.noreply.github.com>

* add building workflow for jax modules
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* replace bit_cast with reinterpret_cast
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* add nvtx to cmake check list
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* refactor layernorm fwd
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* refactor rmsnorm fwd
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* refactor layernorm_bwd
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* set pytorch as default in setup.py
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* rename extension from *.cc to *.cpp

cpplint cannot recognize *.cc file, so rename the extension
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* refactor style, to align TE/PyTorch
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* add pybinding, unittest and qa
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* fix license
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* disable c-extension-no-member and no-name-in-module
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* add dataclass avoid pylint error
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* Update transformer_engine/__init__.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarJeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update tests/jax/test_custom_call_shape.py

fix typo
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarJeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update tests/jax/test_custom_call_shape.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarJeng Bai-Cheng <jeng1220@users.noreply.github.com>

* fix conflict due to PR62
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* fix c-extension-no-member and no-name-in-module

1. add transformer_engine_jax into extension-pkg-whitelist
2. convert pylintrc from CRLF to LF format
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

* Update setup.py
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarJeng Bai-Cheng <jeng1220@users.noreply.github.com>

* remove pylint:disable and refactor import order
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>

---------
Signed-off-by: default avatarRyan Jeng <rjeng@nvidia.com>
Signed-off-by: default avatarJeng Bai-Cheng <jeng1220@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d8a2f352
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment