Unverified Commit c53c8e49 authored by Leo's avatar Leo Committed by GitHub
Browse files

fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is … (#24772)



fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor."
Co-authored-by: default avatar刘长伟 <hzliuchw@corp.netease.com>
parent 04a5c859
......@@ -700,8 +700,13 @@ class BatchEncoding(UserDict):
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch
as_tensor = torch.tensor
is_tensor = torch.is_tensor
def as_tensor(value, dtype=None):
if isinstance(value, list) and isinstance(value[0], np.ndarray):
return torch.tensor(np.array(value))
return torch.tensor(value)
elif tensor_type == TensorType.JAX:
if not is_flax_available():
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment