Unverified Commit 7cef7566 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[TE/JAX] Update required JAX version for FFI custom calls with cudaGraph (#1285)



Update jax version for ffi
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 7b284fef
...@@ -162,7 +162,7 @@ def is_ffi_enabled(): ...@@ -162,7 +162,7 @@ def is_ffi_enabled():
""" """
Helper function checking if XLA Custom Call with FFI is enabled Helper function checking if XLA Custom Call with FFI is enabled
""" """
is_supported = jax_version_meet_requirement("0.4.31") is_supported = jax_version_meet_requirement("0.4.35")
# New APIs with FFI are enabled by default # New APIs with FFI are enabled by default
is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1")) is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value" assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment