Commit bb47653e authored by rusty1s's avatar rusty1s
Browse files

added mean tests

parent a7beacab
...@@ -63,6 +63,22 @@ ...@@ -63,6 +63,22 @@
"fill_value": 1, "fill_value": 1,
"expected": [[0.25, 0.25], [0.125, 0.5]] "expected": [[0.25, 0.25], [0.125, 0.5]]
}, },
{
"name": "mean",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"fill_value": 0,
"expected": [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]]
},
{
"name": "mean",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 0,
"expected": [[3, 2.5], [3, 4]]
},
{ {
"name": "max", "name": "max",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]], "index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
......
...@@ -26,9 +26,10 @@ void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *in ...@@ -26,9 +26,10 @@ void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *in
int64_t i, idx; int64_t i, idx;
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, real, count, dim, TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, real, count, dim,
for (i = 0; i < THLongTensor_size(index, dim); i++) { for (i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter); idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx * output_stride] += *(input_data + i * input_stride); output_data[idx * output_stride] += *(input_data + i * input_stride);
output_data[idx * count_stride]++; count_data[idx * count_stride]++;
}) })
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment