Commit 4af63f28 authored by rusty1s's avatar rusty1s
Browse files

min max

parent d7353409
......@@ -37,3 +37,19 @@ void scatter_mean_Char (int dim, THCharTensor *output, THLongTensor *index, T
void scatter_mean_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THShortTensor *output_count);
void scatter_mean_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THIntTensor *output_count);
void scatter_mean_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_count);
void scatter_max_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *output_index);
void scatter_max_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *output_index);
void scatter_max_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *output_index);
void scatter_max_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *output_index);
void scatter_max_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_index);
void scatter_max_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_index);
void scatter_max_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_index);
void scatter_min_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *output_index);
void scatter_min_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *output_index);
void scatter_min_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *output_index);
void scatter_min_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *output_index);
void scatter_min_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_index);
void scatter_min_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_index);
void scatter_min_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_index);
......@@ -53,4 +53,26 @@ void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *in
})
}
void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_index) {
int64_t idx; real old, new;
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_index, dim, TH_TENSOR_DIM_APPLY4_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
old = output_data[idx]; new = *(input_data + i * input_stride);
if (new >= old) { output_data[idx] = new; output_index_data[idx] = i; }
})
}
void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_index) {
int64_t idx; real old, new;
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_index, dim, TH_TENSOR_DIM_APPLY4_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
old = output_data[idx]; new = *(input_data + i * input_stride);
if (new <= old) { output_data[idx] = new; output_index_data[idx] = i; }
})
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment