Commit f3d82a17 authored by rusty1s's avatar rusty1s
Browse files

bugfix

parent 601e686e
...@@ -38,6 +38,7 @@ def test_scatter_max(str): ...@@ -38,6 +38,7 @@ def test_scatter_max(str):
@pytest.mark.parametrize('str', tensor_strs) @pytest.mark.parametrize('str', tensor_strs)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_scatter_cuda_max(str): def test_scatter_cuda_max(str):
input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]] input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]] index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
......
...@@ -12,7 +12,7 @@ def gen_filled_tensor(input, size, fill_value): ...@@ -12,7 +12,7 @@ def gen_filled_tensor(input, size, fill_value):
def gen_output(index, input, dim, dim_size, fill_value): def gen_output(index, input, dim, dim_size, fill_value):
if dim_size is None: if dim_size is None:
dim_size = index.max() + 1 dim_size = index.max() + 1
dim_size = dim.size if torch.is_tensor(input) else dim_size.data[0] dim_size = dim_size if torch.is_tensor(input) else dim_size.data[0]
size = list(index.size()) size = list(index.size())
size[dim] = dim_size size[dim] = dim_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment