Unverified Commit f8f9a679 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

fix type check (#12638)

parent 2dd9440d
...@@ -1715,10 +1715,10 @@ def is_tensor(x): ...@@ -1715,10 +1715,10 @@ def is_tensor(x):
return True return True
if is_flax_available(): if is_flax_available():
import jaxlib.xla_extension as jax_xla import jax.numpy as jnp
from jax.core import Tracer from jax.core import Tracer
if isinstance(x, (jax_xla.DeviceArray, Tracer)): if isinstance(x, (jnp.ndarray, Tracer)):
return True return True
return isinstance(x, np.ndarray) return isinstance(x, np.ndarray)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment