Unverified Commit 12af02f2 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix `NVTE_FRAMEWORK=all` installation (#1850)



* Fix NVTE_FRAMEWORK=all
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Workflow tests and fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix jax install
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update dep
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add numpy
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add dep
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 97e493f4
...@@ -43,7 +43,7 @@ jobs: ...@@ -43,7 +43,7 @@ jobs:
run: | run: |
apt-get update apt-get update
apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12 apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11 pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
- name: 'Checkout' - name: 'Checkout'
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
...@@ -54,7 +54,6 @@ jobs: ...@@ -54,7 +54,6 @@ jobs:
NVTE_FRAMEWORK: pytorch NVTE_FRAMEWORK: pytorch
MAX_JOBS: 1 MAX_JOBS: 1
- name: 'Sanity check' - name: 'Sanity check'
if: false # Sanity import test requires Flash Attention
run: python3 tests/pytorch/test_sanity_import.py run: python3 tests/pytorch/test_sanity_import.py
jax: jax:
name: 'JAX' name: 'JAX'
...@@ -73,4 +72,24 @@ jobs: ...@@ -73,4 +72,24 @@ jobs:
NVTE_FRAMEWORK: jax NVTE_FRAMEWORK: jax
MAX_JOBS: 1 MAX_JOBS: 1
- name: 'Sanity check' - name: 'Sanity check'
run: python tests/jax/test_sanity_import.py run: python3 tests/jax/test_sanity_import.py
all:
name: 'All'
runs-on: ubuntu-latest
container:
image: ghcr.io/nvidia/jax:jax
options: --user root
steps:
- name: 'Dependencies'
run: pip install torch pybind11[global] einops
- name: 'Checkout'
uses: actions/checkout@v3
with:
submodules: recursive
- name: 'Build'
run: pip install --no-build-isolation . -v
env:
NVTE_FRAMEWORK: all
MAX_JOBS: 1
- name: 'Sanity check'
run: python3 tests/pytorch/test_sanity_import.py && python3 tests/jax/test_sanity_import.py
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "extensions.h" #include "../extensions.h"
#include "transformer_engine/cast.h" #include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "../extensions.h"
#include "transformer_engine/fused_attn.h" #include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "../extensions.h"
#include "transformer_engine/gemm.h" #include "transformer_engine/gemm.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "transformer_engine/cudnn.h" #include "transformer_engine/cudnn.h"
#include "extensions.h" #include "../extensions.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
namespace transformer_engine { namespace transformer_engine {
......
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
#include <memory> #include <memory>
#include "../extensions.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#include "common/util/system.h" #include "common/util/system.h"
#include "extensions.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
namespace transformer_engine { namespace transformer_engine {
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "../extensions.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "extensions.h" #include "../extensions.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "../extensions.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
************************************************************************/ ************************************************************************/
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "extensions.h" #include "../extensions.h"
#include "transformer_engine/cast.h" #include "transformer_engine/cast.h"
#include "transformer_engine/recipe.h" #include "transformer_engine/recipe.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "transformer_engine/softmax.h" #include "transformer_engine/softmax.h"
#include "extensions.h" #include "../extensions.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
namespace transformer_engine { namespace transformer_engine {
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "../extensions.h"
#include "common.h" #include "common.h"
#include "extensions.h"
#include "pybind.h" #include "pybind.h"
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "../extensions.h"
#include "common.h" #include "common.h"
#include "extensions.h"
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "../extensions.h"
#include "common.h" #include "common.h"
#include "extensions.h"
#include "pybind.h" #include "pybind.h"
namespace { namespace {
......
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
#include "transformer_engine/cast.h" #include "transformer_engine/cast.h"
#include "../extensions.h"
#include "common.h" #include "common.h"
#include "extensions.h"
#include "pybind.h" #include "pybind.h"
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "../extensions.h"
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
......
...@@ -10,10 +10,10 @@ ...@@ -10,10 +10,10 @@
#include <string> #include <string>
#include "../common.h" #include "../common.h"
#include "../extensions.h"
#include "common.h" #include "common.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#include "common/util/system.h" #include "common/util/system.h"
#include "extensions.h"
#include "pybind.h" #include "pybind.h"
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
#include "util.h" #include "util.h"
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "../extensions.h"
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "../../extensions.h"
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "../../extensions.h"
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment