Commit 43d09645 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Optimized type-changing of features from numpy to torch

- Bugfix: `torch` throws warnings when copying a tensor via initialization
- Added lambda to `.clone()` those tensors instead
parent 7f84eebd
......@@ -40,9 +40,10 @@ def np_to_tensor_dict(
Returns:
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
"""
to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
tensor_dict = {
k: torch.tensor(v) for k, v in np_example.items() if k in features
k: to_tensor(v) for k, v in np_example.items() if k in features
}
return tensor_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment