Commit 43d09645 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Optimized type-changing of features from numpy to torch

- Bugfix: `torch` throws warnings when copying a tensor via initialization
- Added lambda to `.clone()` those tensors instead
parent 7f84eebd
...@@ -40,9 +40,10 @@ def np_to_tensor_dict( ...@@ -40,9 +40,10 @@ def np_to_tensor_dict(
Returns: Returns:
A dictionary of features mapping feature names to features. Only the given A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out. features are returned, all other ones are filtered out.
""" """
to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
tensor_dict = { tensor_dict = {
k: torch.tensor(v) for k, v in np_example.items() if k in features k: to_tensor(v) for k, v in np_example.items() if k in features
} }
return tensor_dict return tensor_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment