Unverified Commit 8ef3308a authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[TE/JAX] Add default include path for XLA FFI (#1104)



* add default path for ffi include

* add an option to get XLA_HOME from env

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 67900e8d
......@@ -2,7 +2,8 @@
#
# See LICENSE for license information.
"""Paddle-paddle related extensions."""
"""JAX related extensions."""
import os
from pathlib import Path
import setuptools
......@@ -11,7 +12,24 @@ from glob import glob
from .utils import cuda_path, all_files_in_dir
from typing import List
from jax.extend import ffi
def xla_path() -> str:
"""XLA root path lookup.
Throws FileNotFoundError if XLA source is not found."""
try:
from jax.extend import ffi
except ImportError:
if os.getenv("XLA_HOME"):
xla_home = Path(os.getenv("XLA_HOME"))
else:
xla_home = "/opt/xla"
else:
xla_home = ffi.include_dir()
if not os.path.isdir(xla_home):
raise FileNotFoundError("Could not find xla source.")
return xla_home
def setup_jax_extension(
......@@ -29,14 +47,14 @@ def setup_jax_extension(
# Header files
cuda_home, _ = cuda_path()
jax_ffi_include = ffi.include_dir()
xla_home = xla_path()
include_dirs = [
cuda_home / "include",
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
jax_ffi_include,
xla_home,
]
# Compile flags
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment