Unverified Commit 321eb562 authored by Elad Segal's avatar Elad Segal Committed by GitHub
Browse files

`BatchFeature`: Convert `List[np.ndarray]` to `np.ndarray` before converting...


`BatchFeature`: Convert `List[np.ndarray]` to `np.ndarray` before converting to pytorch tensors (#14306)

* update

* style fix

* retrigger checks

* check first element

* fix syntax error

* Update src/transformers/feature_extraction_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove import
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 46d0cdae
......@@ -138,7 +138,11 @@ class BatchFeature(UserDict):
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch
as_tensor = torch.tensor
def as_tensor(value):
if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
value = np.array(value)
return torch.tensor(value)
is_tensor = torch.is_tensor
elif tensor_type == TensorType.JAX:
if not is_flax_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment