Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
35e687d0
Unverified
Commit
35e687d0
authored
Oct 13, 2023
by
Tim Moon
Committed by
GitHub
Oct 13, 2023
Browse files
Remove remaining references to TensorFlow (#474)
Signed-off-by:
Tim Moon
<
tmoon@nvidia.com
>
parent
8e757a45
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
16 additions
and
22 deletions
+16
-22
.github/workflows/build.yml
.github/workflows/build.yml
+2
-5
.github/workflows/lint.yml
.github/workflows/lint.yml
+1
-4
README.rst
README.rst
+12
-12
setup.py
setup.py
+1
-1
No files found.
.github/workflows/build.yml
View file @
35e687d0
...
@@ -31,8 +31,7 @@ jobs:
...
@@ -31,8 +31,7 @@ jobs:
name
:
'
JAX'
name
:
'
JAX'
runs-on
:
ubuntu-latest
runs-on
:
ubuntu-latest
container
:
container
:
#image: nvcr.io/nvidia/jax:XX.XX-py3 # Not yet available
image
:
ghcr.io/nvidia/jax:latest
image
:
nvcr.io/nvidia/tensorflow:23.03-tf2-py3
options
:
--user root
options
:
--user root
steps
:
steps
:
-
name
:
'
Checkout'
-
name
:
'
Checkout'
...
@@ -40,9 +39,7 @@ jobs:
...
@@ -40,9 +39,7 @@ jobs:
with
:
with
:
submodules
:
recursive
submodules
:
recursive
-
name
:
'
Build'
-
name
:
'
Build'
run
:
|
run
:
pip install . -v
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
pip install . -v
env
:
env
:
NVTE_FRAMEWORK
:
jax
NVTE_FRAMEWORK
:
jax
-
name
:
'
Sanity
check'
-
name
:
'
Sanity
check'
...
...
.github/workflows/lint.yml
View file @
35e687d0
...
@@ -50,16 +50,13 @@ jobs:
...
@@ -50,16 +50,13 @@ jobs:
name
:
'
JAX
Python'
name
:
'
JAX
Python'
runs-on
:
ubuntu-latest
runs-on
:
ubuntu-latest
container
:
container
:
#image: nvcr.io/nvidia/jax:XX.XX-py3 # Not yet available
image
:
ghcr.io/nvidia/jax:latest
image
:
nvcr.io/nvidia/tensorflow:23.03-tf2-py3
options
:
--user root
options
:
--user root
steps
:
steps
:
-
name
:
'
Checkout'
-
name
:
'
Checkout'
uses
:
actions/checkout@v3
uses
:
actions/checkout@v3
-
name
:
'
Lint'
-
name
:
'
Lint'
run
:
|
run
:
|
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install flax
export PYTHON_ONLY=1
export PYTHON_ONLY=1
export TE_PATH=.
export TE_PATH=.
bash ./qa/L0_jax_lint/test.sh
bash ./qa/L0_jax_lint/test.sh
README.rst
View file @
35e687d0
...
@@ -174,7 +174,7 @@ While the more granular modules in Transformer Engine allow building any Transfo
...
@@ -174,7 +174,7 @@ While the more granular modules in Transformer Engine allow building any Transfo
the
`
TransformerLayer
`
API
of
Transformer
Engine
is
flexible
enough
to
build
multiple
major
the
`
TransformerLayer
`
API
of
Transformer
Engine
is
flexible
enough
to
build
multiple
major
Transformer
model
architectures
.
Transformer
model
architectures
.
Transformer
Engine
supports
the
following
DL
frameworks
:
PyTorch
,
JAX
(
Flax
,
Praxis
)
,
and
TensorFlow
.
Transformer
Engine
supports
the
following
DL
frameworks
:
PyTorch
and
JAX
(
Flax
,
Praxis
).
NOTE
:
For
simplicity
,
we
only
show
PyTorch
examples
below
.
For
the
usage
of
`
TransformerLayer
`
NOTE
:
For
simplicity
,
we
only
show
PyTorch
examples
below
.
For
the
usage
of
`
TransformerLayer
`
of
all
supported
frameworks
,
refer
to
`
examples
<
https
://
github
.
com
/
NVIDIA
/
TransformerEngine
/
tree
/
main
/
examples
>`
_
.
of
all
supported
frameworks
,
refer
to
`
examples
<
https
://
github
.
com
/
NVIDIA
/
TransformerEngine
/
tree
/
main
/
examples
>`
_
.
...
...
setup.py
View file @
35e687d0
...
@@ -434,7 +434,7 @@ class CMakeBuildExtension(BuildExtension):
...
@@ -434,7 +434,7 @@ class CMakeBuildExtension(BuildExtension):
def
setup_common_extension
()
->
CMakeExtension
:
def
setup_common_extension
()
->
CMakeExtension
:
"""Setup CMake extension for common library
"""Setup CMake extension for common library
Also builds JAX
, TensorFlow, and
userbuffers support if needed.
Also builds JAX
or
userbuffers support if needed.
"""
"""
cmake_flags
=
[]
cmake_flags
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment