Unverified Commit 01ae3b87 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Avoid `jnp` import in `utils/generic.py` (#30322)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 60d5f8f9
...@@ -38,10 +38,6 @@ from .import_utils import ( ...@@ -38,10 +38,6 @@ from .import_utils import (
) )
if is_flax_available():
import jax.numpy as jnp
class cached_property(property): class cached_property(property):
""" """
Descriptor that mimics @property but caches output in member variable. Descriptor that mimics @property but caches output in member variable.
...@@ -624,6 +620,8 @@ def transpose(array, axes=None): ...@@ -624,6 +620,8 @@ def transpose(array, axes=None):
return tf.transpose(array, perm=axes) return tf.transpose(array, perm=axes)
elif is_jax_tensor(array): elif is_jax_tensor(array):
import jax.numpy as jnp
return jnp.transpose(array, axes=axes) return jnp.transpose(array, axes=axes)
else: else:
raise ValueError(f"Type not supported for transpose: {type(array)}.") raise ValueError(f"Type not supported for transpose: {type(array)}.")
...@@ -643,6 +641,8 @@ def reshape(array, newshape): ...@@ -643,6 +641,8 @@ def reshape(array, newshape):
return tf.reshape(array, newshape) return tf.reshape(array, newshape)
elif is_jax_tensor(array): elif is_jax_tensor(array):
import jax.numpy as jnp
return jnp.reshape(array, newshape) return jnp.reshape(array, newshape)
else: else:
raise ValueError(f"Type not supported for reshape: {type(array)}.") raise ValueError(f"Type not supported for reshape: {type(array)}.")
...@@ -662,6 +662,8 @@ def squeeze(array, axis=None): ...@@ -662,6 +662,8 @@ def squeeze(array, axis=None):
return tf.squeeze(array, axis=axis) return tf.squeeze(array, axis=axis)
elif is_jax_tensor(array): elif is_jax_tensor(array):
import jax.numpy as jnp
return jnp.squeeze(array, axis=axis) return jnp.squeeze(array, axis=axis)
else: else:
raise ValueError(f"Type not supported for squeeze: {type(array)}.") raise ValueError(f"Type not supported for squeeze: {type(array)}.")
...@@ -681,6 +683,8 @@ def expand_dims(array, axis): ...@@ -681,6 +683,8 @@ def expand_dims(array, axis):
return tf.expand_dims(array, axis=axis) return tf.expand_dims(array, axis=axis)
elif is_jax_tensor(array): elif is_jax_tensor(array):
import jax.numpy as jnp
return jnp.expand_dims(array, axis=axis) return jnp.expand_dims(array, axis=axis)
else: else:
raise ValueError(f"Type not supported for expand_dims: {type(array)}.") raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment