Commit a7beacab authored by rusty1s's avatar rusty1s
Browse files

bugfix with tensor strides

parent aeca7758
......@@ -4,7 +4,7 @@ from setuptools import setup, find_packages
import build # noqa
__version__ = '0.1.3'
__version__ = '0.2.0'
url = 'https://github.com/rusty1s/pytorch_scatter'
install_requires = ['cffi']
......
......@@ -8,6 +8,15 @@
"grad": [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]],
"expected": [[50, 60, 50, 30, 40], [15, 15, 35, 35, 25]]
},
{
"name": "add",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 0,
"grad": [[10, 20], [15, 25]],
"expected": [[10, 20], [15, 25], [15, 25], [10, 20]]
},
{
"name": "mean",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
......@@ -17,6 +26,15 @@
"grad": [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]],
"expected": [[50, 60, 50, 30, 40], [15, 15, 35, 35, 25]]
},
{
"name": "mean",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 0,
"grad": [[10, 20], [15, 25]],
"expected": [[10, 20], [15, 25], [15, 25], [10, 20]]
},
{
"name": "max",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
......@@ -25,5 +43,14 @@
"fill_value": 0,
"grad": [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]],
"expected": [[50, 60, 0, 30, 40], [0, 15, 0, 35, 25]]
},
{
"name": "max",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 0,
"grad": [[10, 20], [15, 25]],
"expected": [[10, 0], [0, 25], [15, 0], [0, 20]]
}
]
......@@ -7,6 +7,14 @@
"fill_value": 0,
"expected": [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]
},
{
"name": "add",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 0,
"expected": [[6, 5], [6, 8]]
},
{
"name": "sub",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
......@@ -15,6 +23,14 @@
"fill_value": 9,
"expected": [[9, 9, 5, 6, 6, 9], [7, 5, 5, 9, 9, 9]]
},
{
"name": "sub",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 9,
"expected": [[3, 4], [3, 1]]
},
{
"name": "mul",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
......@@ -23,6 +39,14 @@
"fill_value": 1,
"expected": [[1, 1, 4, 3, 2, 0], [0, 4, 3, 1, 1, 1]]
},
{
"name": "mul",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 1,
"expected": [[5, 6], [8, 15]]
},
{
"name": "div",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
......@@ -32,12 +56,12 @@
"expected": [[1, 1, 0.25, 0.5, 0.5, 1], [0.5, 0.25, 0.5, 1, 1, 1]]
},
{
"name": "mean",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"fill_value": 0,
"expected": [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]]
"name": "div",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[4, 2], [2, 1], [4, 2], [1, 2]],
"dim": 0,
"fill_value": 1,
"expected": [[0.25, 0.25], [0.125, 0.5]]
},
{
"name": "max",
......@@ -48,6 +72,15 @@
"expected": [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]],
"expected_arg": [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
},
{
"name": "max",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 0,
"expected": [[5, 3], [4, 5]],
"expected_arg": [[0, 3], [2, 1]]
},
{
"name": "min",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
......@@ -56,5 +89,14 @@
"fill_value": 9,
"expected": [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]],
"expected_arg": [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]]
},
{
"name": "min",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 9,
"expected": [[1, 2], [2, 3]],
"expected_arg": [[3, 0], [1, 2]]
}
]
......@@ -6,7 +6,7 @@ from .functions.mean import scatter_mean_, scatter_mean
from .functions.max import scatter_max_, scatter_max
from .functions.min import scatter_min_, scatter_min
__version__ = '0.1.3'
__version__ = '0.2.0'
__all__ = [
'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',
......
......@@ -3,62 +3,67 @@
#else
void scatter_(mul)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
int64_t i;
int64_t i, idx;
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[index_data[i]] *= input_data[i];
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx * output_stride] *= *(input_data + i * input_stride);
})
}
void scatter_(div)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
int64_t i;
int64_t i, idx;
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[index_data[i]] /= input_data[i];
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx * output_stride] /= *(input_data + i * input_stride);
})
}
void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THTensor *count) {
int64_t i;
int64_t i, idx;
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, real, count, dim,
for (i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[index_data[i]] += input_data[i];
count_data[index_data[i]]++;
output_data[idx * output_stride] += *(input_data + i * input_stride);
output_data[idx * count_stride]++;
})
}
void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *arg) {
int64_t i;
int64_t i, idx;
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, arg, dim,
for (i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
if (input_data[i] >= output_data[index_data[i]]) {
output_data[index_data[i]] = input_data[i];
arg_data[index_data[i]] = i;
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
if (*(input_data + i * input_stride) >= *(output_data + idx * output_stride)) {
output_data[idx * output_stride] = *(input_data + i * input_stride);
arg_data[idx * arg_stride] = i;
}
})
}
void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *arg) {
int64_t i;
int64_t i, idx;
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, arg, dim,
for (i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
if (input_data[i] <= output_data[index_data[i]]) {
output_data[index_data[i]] = input_data[i];
arg_data[index_data[i]] = i;
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
if (*(input_data + i * input_stride) <= *(output_data + idx * output_stride)) {
output_data[idx * output_stride] = *(input_data + i * input_stride);
arg_data[idx * arg_stride] = i;
}
})
}
void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *arg) {
int64_t i;
int64_t i, idx;
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, arg, dim,
for (i = 0; i < THLongTensor_size(index, dim); i++) {
if (arg_data[index_data[i]] == i) output_data[i] = grad_data[index_data[i]];
idx = *(index_data + i * index_stride);
if (*(arg_data + idx * arg_stride) == i) output_data[i * output_stride] = *(grad_data + idx * grad_stride);
})
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment