Unverified Commit e7602a4c authored by Crutcher Dunnavant's avatar Crutcher Dunnavant Committed by GitHub
Browse files

[minor] test split_non_tensor(<Tensor>) case (#1000)

parent 2bd85c05
...@@ -78,6 +78,8 @@ def split_non_tensors( ...@@ -78,6 +78,8 @@ def split_non_tensors(
Split a tuple into a list of tensors and the rest with information Split a tuple into a list of tensors and the rest with information
for later reconstruction. for later reconstruction.
When called with a tensor X, will return: (x,), None
Usage:: Usage::
x = torch.Tensor([1]) x = torch.Tensor([1])
......
...@@ -99,6 +99,11 @@ def test_split_unpack(): ...@@ -99,6 +99,11 @@ def test_split_unpack():
x = torch.Tensor([1]) x = torch.Tensor([1])
y = torch.Tensor([2]) y = torch.Tensor([2])
# degenerate case, args is a single tensor.
tensors, packed_non_tensors = split_non_tensors(x)
assert tensors == (x,)
assert packed_non_tensors is None
tensors, packed_non_tensors = split_non_tensors((x, y, None, 3)) tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
assert tensors == (x, y) assert tensors == (x, y)
assert packed_non_tensors == { assert packed_non_tensors == {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment