Unverified Commit 8ef3308a authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[TE/JAX] Add default include path for XLA FFI (#1104)



* add default path for ffi include

* add an option to get XLA_HOME from env

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 67900e8d
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""Paddle-paddle related extensions.""" """JAX related extensions."""
import os
from pathlib import Path from pathlib import Path
import setuptools import setuptools
...@@ -11,7 +12,24 @@ from glob import glob ...@@ -11,7 +12,24 @@ from glob import glob
from .utils import cuda_path, all_files_in_dir from .utils import cuda_path, all_files_in_dir
from typing import List from typing import List
from jax.extend import ffi
def xla_path() -> str:
"""XLA root path lookup.
Throws FileNotFoundError if XLA source is not found."""
try:
from jax.extend import ffi
except ImportError:
if os.getenv("XLA_HOME"):
xla_home = Path(os.getenv("XLA_HOME"))
else:
xla_home = "/opt/xla"
else:
xla_home = ffi.include_dir()
if not os.path.isdir(xla_home):
raise FileNotFoundError("Could not find xla source.")
return xla_home
def setup_jax_extension( def setup_jax_extension(
...@@ -29,14 +47,14 @@ def setup_jax_extension( ...@@ -29,14 +47,14 @@ def setup_jax_extension(
# Header files # Header files
cuda_home, _ = cuda_path() cuda_home, _ = cuda_path()
jax_ffi_include = ffi.include_dir() xla_home = xla_path()
include_dirs = [ include_dirs = [
cuda_home / "include", cuda_home / "include",
common_header_files, common_header_files,
common_header_files / "common", common_header_files / "common",
common_header_files / "common" / "include", common_header_files / "common" / "include",
csrc_header_files, csrc_header_files,
jax_ffi_include, xla_home,
] ]
# Compile flags # Compile flags
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment