"...git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "d5c8a4df32781cb12d7a165c814331d6d1de1eef"
Commit 143a57ec authored by rusty1s's avatar rusty1s
Browse files

searching for bug

parent 28f42bdc
...@@ -6,7 +6,8 @@ from torch_scatter import scatter_max_, scatter_max ...@@ -6,7 +6,8 @@ from torch_scatter import scatter_max_, scatter_max
from .utils import tensor_strs, Tensor from .utils import tensor_strs, Tensor
@pytest.mark.parametrize('str', tensor_strs) # @pytest.mark.parametrize('str', tensor_strs)
@pytest.mark.parametrize('str', ['DoubleTensor'])
def test_scatter_mean(str): def test_scatter_mean(str):
input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]] input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]] index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
......
...@@ -65,10 +65,14 @@ void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *inp ...@@ -65,10 +65,14 @@ void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
}) })
} }
void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *grad_index) { void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *grad_index) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, grad_index, dim, TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, grad_index, dim,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) { for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
if (grad_index_data[index_data[i]] == i) output_data[index_data[i]] = grad_data[i]; int64_t idx = *(index_data + i * index_stride);
/* if (grad_index_data[index_data[i]] == i) { */
/* printf("i: %i, ", i); */
/* output_data[i] = grad_data[idx]; */
/* } */
}) })
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment