Unverified Commit d3157e2a authored by Frédéric Bastien's avatar Frédéric Bastien Committed by GitHub
Browse files

Make an import optional for a JAX API change that will happen soon. (#454)


Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>
parent 5f60f82f
...@@ -11,7 +11,6 @@ import operator ...@@ -11,7 +11,6 @@ import operator
import warnings import warnings
import numpy as np import numpy as np
from jaxlib.hlo_helpers import custom_call
import jax.numpy as jnp import jax.numpy as jnp
from jax.lib import xla_client from jax.lib import xla_client
from jax import core, dtypes from jax import core, dtypes
...@@ -19,6 +18,13 @@ from jax.core import ShapedArray ...@@ -19,6 +18,13 @@ from jax.core import ShapedArray
from jax.interpreters import xla, mlir from jax.interpreters import xla, mlir
from jax.interpreters.mlir import ir, dtype_to_ir_type from jax.interpreters.mlir import ir, dtype_to_ir_type
try:
from jaxlib.hlo_helpers import custom_call
except ImportError:
# Newer JAX changed its API. But we want to support a few JAX
# version, so we still need this import.
pass
import transformer_engine_jax import transformer_engine_jax
from transformer_engine_jax import DType as TEDType from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type from transformer_engine_jax import NVTE_Bias_Type
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment