Unverified Commit c53c8e49 authored by Leo's avatar Leo Committed by GitHub
Browse files

fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is … (#24772)



fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor."
Co-authored-by: default avatar刘长伟 <hzliuchw@corp.netease.com>
parent 04a5c859
...@@ -700,8 +700,13 @@ class BatchEncoding(UserDict): ...@@ -700,8 +700,13 @@ class BatchEncoding(UserDict):
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch import torch
as_tensor = torch.tensor
is_tensor = torch.is_tensor is_tensor = torch.is_tensor
def as_tensor(value, dtype=None):
if isinstance(value, list) and isinstance(value[0], np.ndarray):
return torch.tensor(np.array(value))
return torch.tensor(value)
elif tensor_type == TensorType.JAX: elif tensor_type == TensorType.JAX:
if not is_flax_available(): if not is_flax_available():
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment