Unverified Commit 321eb562 authored by Elad Segal's avatar Elad Segal Committed by GitHub
Browse files

`BatchFeature`: Convert `List[np.ndarray]` to `np.ndarray` before converting...


`BatchFeature`: Convert `List[np.ndarray]` to `np.ndarray` before converting to pytorch tensors (#14306)

* update

* style fix

* retrigger checks

* check first element

* fix syntax error

* Update src/transformers/feature_extraction_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove import
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 46d0cdae
...@@ -138,7 +138,11 @@ class BatchFeature(UserDict): ...@@ -138,7 +138,11 @@ class BatchFeature(UserDict):
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch import torch
as_tensor = torch.tensor def as_tensor(value):
if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
value = np.array(value)
return torch.tensor(value)
is_tensor = torch.is_tensor is_tensor = torch.is_tensor
elif tensor_type == TensorType.JAX: elif tensor_type == TensorType.JAX:
if not is_flax_available(): if not is_flax_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment