Unverified Commit 9d031fbd authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX BUILD] Fixes for JAX 0.7.0 (#1936)



* Fix jax build
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 9166d4df
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
"""JAX related extensions.""" """JAX related extensions."""
import os import os
from pathlib import Path from pathlib import Path
from packaging import version
import setuptools import setuptools
...@@ -27,7 +28,13 @@ def xla_path() -> str: ...@@ -27,7 +28,13 @@ def xla_path() -> str:
Throws FileNotFoundError if XLA source is not found.""" Throws FileNotFoundError if XLA source is not found."""
try: try:
from jax.extend import ffi import jax
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
except ImportError: except ImportError:
if os.getenv("XLA_HOME"): if os.getenv("XLA_HOME"):
xla_home = Path(os.getenv("XLA_HOME")) xla_home = Path(os.getenv("XLA_HOME"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment