"tests/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "0024a5c66f90c7d3d02f7ef08a773aace6deb155"
Commit ba26dfb1 authored by rusty1s's avatar rusty1s
Browse files

faster

parent bb47653e
......@@ -3,9 +3,10 @@
#else
void scatter_(mul)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
int64_t i, idx;
int64_t n, i, idx;
n = THLongTensor_size(index, dim);
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (i = 0; i < THLongTensor_size(index, dim); i++) {
for (i = 0; i < n; i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx * output_stride] *= *(input_data + i * input_stride);
......@@ -13,9 +14,10 @@ void scatter_(mul)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
}
void scatter_(div)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
int64_t i, idx;
int64_t n, i, idx;
n = THLongTensor_size(index, dim);
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (i = 0; i < THLongTensor_size(index, dim); i++) {
for (i = 0; i < n; i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx * output_stride] /= *(input_data + i * input_stride);
......@@ -23,9 +25,10 @@ void scatter_(div)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
}
void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THTensor *count) {
int64_t i, idx;
int64_t n, i, idx;
n = THLongTensor_size(index, dim);
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, real, count, dim,
for (i = 0; i < THLongTensor_size(index, dim); i++) {
for (i = 0; i < n; i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx * output_stride] += *(input_data + i * input_stride);
......@@ -34,9 +37,10 @@ void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *in
}
void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *arg) {
int64_t i, idx;
int64_t n, i, idx;
n = THLongTensor_size(index, dim);
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, arg, dim,
for (i = 0; i < THLongTensor_size(index, dim); i++) {
for (i = 0; i < n; i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
if (*(input_data + i * input_stride) >= *(output_data + idx * output_stride)) {
......@@ -47,9 +51,10 @@ void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
}
void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *arg) {
int64_t i, idx;
int64_t n, i, idx;
n = THLongTensor_size(index, dim);
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, arg, dim,
for (i = 0; i < THLongTensor_size(index, dim); i++) {
for (i = 0; i < n; i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
if (*(input_data + i * input_stride) <= *(output_data + idx * output_stride)) {
......@@ -60,9 +65,10 @@ void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
}
void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *arg) {
int64_t i, idx;
int64_t n, i, idx;
n = THLongTensor_size(index, dim);
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, arg, dim,
for (i = 0; i < THLongTensor_size(index, dim); i++) {
for (i = 0; i < n; i++) {
idx = *(index_data + i * index_stride);
if (*(arg_data + idx * arg_stride) == i) output_data[i * output_stride] = *(grad_data + idx * grad_stride);
})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment