"docs/source/de/preprocessing.md" did not exist on "3015d12bfb8b5d65affe05476ee9fe636c7bba0e"
Unverified Commit 704b3f74, authored by Y4hL, committed by GitHub

Add mlx support to BatchEncoding.convert_to_tensors (#29406)

* Add mlx support

* Fix import order and use def instead of lambda

* Another fix for ruff format :)

* Add detecting mlx from repr, add is_mlx_array
Parent commit: 39ef3fb2
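In practice, this change lets callers request MLX arrays the same way they already request PyTorch, TensorFlow, NumPy or JAX tensors. A minimal sketch of the new behaviour, assuming `mlx` is installed (the checkpoint name is only an illustrative choice):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# With this change, "mlx" is accepted as a tensor type and the values of the
# returned BatchEncoding are mlx.core.array instances.
batch = tokenizer(["Hello world", "MLX on Apple silicon"], padding=True, return_tensors="mlx")
print(type(batch["input_ids"]))
```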
@@ -48,6 +48,7 @@ from .utils import (
     extract_commit_hash,
     is_flax_available,
     is_jax_tensor,
+    is_mlx_available,
     is_numpy_array,
     is_offline_mode,
     is_remote_url,
@@ -726,6 +727,16 @@ class BatchEncoding(UserDict):
             as_tensor = jnp.array
             is_tensor = is_jax_tensor

+        elif tensor_type == TensorType.MLX:
+            if not is_mlx_available():
+                raise ImportError("Unable to convert output to MLX tensors format, MLX is not installed.")
+            import mlx.core as mx
+
+            as_tensor = mx.array
+
+            def is_tensor(obj):
+                return isinstance(obj, mx.array)
+
         else:

             def as_tensor(value, dtype=None):
...
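The new branch mirrors the existing JAX handling: a guarded import, `mx.array` as the constructor, and an `isinstance` check as the tensor test. A rough sketch of exercising `convert_to_tensors` directly with the new tensor type (illustrative token ids, assuming `mlx` is installed):

```python
import mlx.core as mx
from transformers import BatchEncoding

# Plain Python lists, as a tokenizer produces them before conversion.
encoding = BatchEncoding({"input_ids": [[101, 2023, 102]], "attention_mask": [[1, 1, 1]]})

# Converts the values in place via the new TensorType.MLX branch.
encoding.convert_to_tensors(tensor_type="mlx")
assert isinstance(encoding["input_ids"], mx.array)
```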
@@ -134,6 +134,7 @@ from .import_utils import (
     is_keras_nlp_available,
     is_levenshtein_available,
     is_librosa_available,
+    is_mlx_available,
     is_natten_available,
     is_ninja_available,
     is_nltk_available,
...
@@ -28,7 +28,14 @@ from typing import Any, ContextManager, Iterable, List, Tuple

 import numpy as np
 from packaging import version

-from .import_utils import get_torch_version, is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
+from .import_utils import (
+    get_torch_version,
+    is_flax_available,
+    is_mlx_available,
+    is_tf_available,
+    is_torch_available,
+    is_torch_fx_proxy,
+)


 if is_flax_available():
@@ -87,6 +94,8 @@ def infer_framework_from_repr(x):
         return "jax"
     elif representation.startswith("<class 'numpy."):
         return "np"
+    elif representation.startswith("<class 'mlx."):
+        return "mlx"


 def _get_frameworks_and_test_func(x):
@@ -99,6 +108,7 @@ def _get_frameworks_and_test_func(x):
         "tf": is_tf_tensor,
         "jax": is_jax_tensor,
         "np": is_numpy_array,
+        "mlx": is_mlx_array,
     }
     preferred_framework = infer_framework_from_repr(x)
     # We will test this one first, then numpy, then the others.
@@ -111,8 +121,8 @@ def _get_frameworks_and_test_func(x):

 def is_tensor(x):
     """
-    Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray` in the order
-    defined by `infer_framework_from_repr`
+    Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray`, `np.ndarray` or `mlx.array`
+    in the order defined by `infer_framework_from_repr`
     """
     # This gives us a smart order to test the frameworks with the corresponding tests.
     framework_to_test_func = _get_frameworks_and_test_func(x)
@@ -231,6 +241,19 @@ def is_jax_tensor(x):
     return False if not is_flax_available() else _is_jax(x)


+def _is_mlx(x):
+    import mlx.core as mx
+
+    return isinstance(x, mx.array)
+
+
+def is_mlx_array(x):
+    """
+    Tests if `x` is a mlx array or not. Safe to call even when mlx is not installed.
+    """
+    return False if not is_mlx_available() else _is_mlx(x)
+
+
 def to_py_obj(obj):
     """
     Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
@@ -499,6 +522,7 @@ class TensorType(ExplicitEnum):
     TENSORFLOW = "tf"
     NUMPY = "np"
     JAX = "jax"
+    MLX = "mlx"


 class ContextManagers:
...
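Taken together, these hunks teach the generic framework-detection helpers to recognize MLX arrays alongside Torch, TensorFlow, JAX and NumPy tensors. A small illustration, assuming `mlx` is installed (the import paths below assume the helpers live in `transformers.utils.generic`, as the surrounding context suggests):

```python
import mlx.core as mx

from transformers.utils import is_mlx_available
from transformers.utils.generic import TensorType, is_mlx_array, is_tensor

x = mx.array([1, 2, 3])

print(is_mlx_available())  # True once the mlx package is importable
print(is_mlx_array(x))     # True; safe to call even when mlx is missing
print(is_tensor(x))        # True, via the new "mlx" entry in the test map
print(TensorType("mlx"))   # TensorType.MLX
```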
@@ -145,6 +145,7 @@ _tokenizers_available = _is_package_available("tokenizers")
 _torchaudio_available = _is_package_available("torchaudio")
 _torchdistx_available = _is_package_available("torchdistx")
 _torchvision_available = _is_package_available("torchvision")
+_mlx_available = _is_package_available("mlx")

 _torch_version = "N/A"
@@ -923,6 +924,10 @@ def is_jinja_available():
     return _jinja_available


+def is_mlx_available():
+    return _mlx_available
+
+
 # docstyle-ignore
 CV2_IMPORT_ERROR = """
 {0} requires the OpenCV library but it was not found in your environment. You can install it with:
...
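The availability flag reuses the existing `_is_package_available` pattern, so downstream code can cheaply guard MLX-specific paths without importing the package eagerly. A hedged sketch of such a guard:

```python
from transformers.utils import is_mlx_available

if is_mlx_available():
    import mlx.core as mx

    x = mx.array([[1.0, 2.0], [3.0, 4.0]])
    print(x.dtype)
else:
    print("mlx is not installed; MLX tensor conversion is unavailable")
```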