Unverified Commit 12af02f2 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix `NVTE_FRAMEWORK=all` installation (#1850)



* Fix NVTE_FRAMEWORK=all
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Workflow tests and fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix jax install
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update dep
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add numpy
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add dep
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 97e493f4
......@@ -43,7 +43,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11
pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
- name: 'Checkout'
uses: actions/checkout@v3
with:
......@@ -54,7 +54,6 @@ jobs:
NVTE_FRAMEWORK: pytorch
MAX_JOBS: 1
- name: 'Sanity check'
if: false # Sanity import test requires Flash Attention
run: python3 tests/pytorch/test_sanity_import.py
jax:
name: 'JAX'
......@@ -73,4 +72,24 @@ jobs:
NVTE_FRAMEWORK: jax
MAX_JOBS: 1
- name: 'Sanity check'
run: python tests/jax/test_sanity_import.py
run: python3 tests/jax/test_sanity_import.py
all:
name: 'All'
runs-on: ubuntu-latest
container:
image: ghcr.io/nvidia/jax:jax
options: --user root
steps:
- name: 'Dependencies'
run: pip install torch pybind11[global] einops
- name: 'Checkout'
uses: actions/checkout@v3
with:
submodules: recursive
- name: 'Build'
run: pip install --no-build-isolation . -v
env:
NVTE_FRAMEWORK: all
MAX_JOBS: 1
- name: 'Sanity check'
run: python3 tests/pytorch/test_sanity_import.py && python3 tests/jax/test_sanity_import.py
......@@ -7,7 +7,7 @@
#include <cuda_runtime.h>
#include "extensions.h"
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "../extensions.h"
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "../extensions.h"
#include "transformer_engine/gemm.h"
#include "xla/ffi/api/c_api.h"
......
......@@ -6,7 +6,7 @@
#include "transformer_engine/cudnn.h"
#include "extensions.h"
#include "../extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
......
......@@ -7,9 +7,9 @@
#include <memory>
#include "../extensions.h"
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "../extensions.h"
namespace transformer_engine {
namespace jax {
......
......@@ -7,7 +7,7 @@
#include <cuda_runtime.h>
#include "extensions.h"
#include "../extensions.h"
namespace transformer_engine {
namespace jax {
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "../extensions.h"
namespace transformer_engine {
namespace jax {
......
......@@ -5,7 +5,7 @@
************************************************************************/
#include <cuda_runtime.h>
#include "extensions.h"
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/recipe.h"
#include "xla/ffi/api/c_api.h"
......
......@@ -6,7 +6,7 @@
#include "transformer_engine/softmax.h"
#include "extensions.h"
#include "../extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
......
......@@ -4,8 +4,8 @@
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#include "common.h"
#include "extensions.h"
#include "pybind.h"
namespace transformer_engine::pytorch {
......
......@@ -4,8 +4,8 @@
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#include "common.h"
#include "extensions.h"
namespace transformer_engine::pytorch {
......
......@@ -4,8 +4,8 @@
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#include "common.h"
#include "extensions.h"
#include "pybind.h"
namespace {
......
......@@ -6,8 +6,8 @@
#include "transformer_engine/cast.h"
#include "../extensions.h"
#include "common.h"
#include "extensions.h"
#include "pybind.h"
#include "transformer_engine/transformer_engine.h"
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "../extensions.h"
namespace transformer_engine::pytorch {
......
......@@ -10,10 +10,10 @@
#include <string>
#include "../common.h"
#include "../extensions.h"
#include "common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "extensions.h"
#include "pybind.h"
#include "transformer_engine/transformer_engine.h"
#include "util.h"
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "../extensions.h"
namespace transformer_engine::pytorch {
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "../../extensions.h"
namespace transformer_engine::pytorch {
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "../../extensions.h"
namespace transformer_engine::pytorch {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment