Commit 556a8c06 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

refactor index_select

parent ed2a1c04
......@@ -6,15 +6,16 @@
#include <cstring>
template <typename T>
void rule_index_select(at::Tensor target, at::Tensor src, Int nRules,
at::Tensor rule_index_select(at::Tensor src, Int nRules,
Int *rules) {
auto n = src.size(1);
auto target = at::empty({nRules, n}, src.type());
auto t_ptr = target.data<T>();
auto s_ptr = src.data<T>();
auto n = target.size(1);
Int i;
#pragma omp parallel for private(i)
for (i = 0; i < nRules; ++i)
#pragma omp parallel for
for (Int i = 0; i < nRules; ++i)
std::memcpy(t_ptr + i * n, s_ptr + rules[2 * i] * n, sizeof(T) * n);
return target;
}
template <typename T>
void rule_index_add_(at::Tensor target, at::Tensor src, Int nRules,
......@@ -22,9 +23,8 @@ void rule_index_add_(at::Tensor target, at::Tensor src, Int nRules,
auto t_ptr = target.data<T>();
auto s_ptr = src.data<T>();
auto n = target.size(1);
Int i;
#pragma omp parallel for private(i)
for (i = 0; i < nRules; ++i) {
#pragma omp parallel for
for (Int i = 0; i < nRules; ++i) {
auto t = t_ptr + rules[2 * i] * n;
auto s = s_ptr + i * n;
for (int j = 0; j < n; ++j)
......@@ -62,8 +62,7 @@ double cpu_Convolution_updateOutput(
// auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type());
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto w = weight.select(0, i);
auto output_rows = at::mm(input_rows, w);
rule_index_add_<T>(output_features, output_rows, nRules, &r[1]);
......@@ -90,8 +89,6 @@ void cpu_Convolution_backward(
if (nActive and d_bias.numel())
at::sum_out(d_bias, d_output_features, {0}, false);
auto ip = weight.size(1);
auto op = weight.size(2);
for (Int i = 0; i < (Int)_rules.size(); i++) {
auto r = _rules[i];
int nRules = r.size() / 2;
......@@ -105,10 +102,8 @@ void cpu_Convolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type());
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto d_output_rows = at::empty({nRules, op}, d_output_features.type());
rule_index_select<T>(d_output_rows, d_output_features, nRules, &r[1]);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
......@@ -144,8 +139,7 @@ double cpu_SubmanifoldConvolution_updateOutput(
// auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type());
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto w = weight.select(0, i);
auto output_rows = at::mm(input_rows, w);
rule_index_add_<T>(output_features, output_rows, nRules, &r[1]);
......@@ -171,8 +165,6 @@ void cpu_SubmanifoldConvolution_backward(
if (nActive and d_bias.numel())
at::sum_out(d_bias, d_output_features, {0}, false);
auto ip = weight.size(1);
auto op = weight.size(2);
for (Int i = 0; i < (Int)_rules.size(); i++) {
auto r = _rules[i];
int nRules = r.size() / 2;
......@@ -186,10 +178,8 @@ void cpu_SubmanifoldConvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type());
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto d_output_rows = at::empty({nRules, op}, d_output_features.type());
rule_index_select<T>(d_output_rows, d_output_features, nRules, &r[1]);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
......@@ -224,8 +214,7 @@ double cpu_PermutohedralSubmanifoldConvolution_updateOutput(
// auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type());
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto w = weight.select(0, i);
auto output_rows = at::mm(input_rows, w);
rule_index_add_<T>(output_features, output_rows, nRules, &r[1]);
......@@ -250,8 +239,6 @@ void cpu_PermutohedralSubmanifoldConvolution_backward(
if (nActive and d_bias.numel())
at::sum_out(d_bias, d_output_features, {0}, false);
auto ip = weight.size(1);
auto op = weight.size(2);
for (Int i = 0; i < (Int)_rules.size(); i++) {
auto r = _rules[i];
int nRules = r.size() / 2;
......@@ -265,10 +252,8 @@ void cpu_PermutohedralSubmanifoldConvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type());
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto d_output_rows = at::empty({nRules, op}, d_output_features.type());
rule_index_select<T>(d_output_rows, d_output_features, nRules, &r[1]);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
......@@ -307,8 +292,7 @@ double cpu_FullConvolution_updateOutput(
// auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type());
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto w = weight.select(0, i);
auto output_rows = at::mm(input_rows, w);
rule_index_add_<T>(output_features, output_rows, nRules, &r[1]);
......@@ -337,8 +321,6 @@ void cpu_FullConvolution_backward(
if (nActive and d_bias.numel())
at::sum_out(d_bias, d_output_features, {0}, false);
auto ip = weight.size(1);
auto op = weight.size(2);
for (Int i = 0; i < (Int)_rules.size(); i++) {
auto r = _rules[i];
int nRules = r.size() / 2;
......@@ -352,10 +334,8 @@ void cpu_FullConvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type());
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto d_output_rows = at::empty({nRules, op}, d_output_features.type());
rule_index_select<T>(d_output_rows, d_output_features, nRules, &r[1]);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
......@@ -393,8 +373,7 @@ double cpu_RandomizedStrideConvolution_updateOutput(
// auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type());
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto w = weight.select(0, i);
auto output_rows = at::mm(input_rows, w);
rule_index_add_<T>(output_features, output_rows, nRules, &r[1]);
......@@ -421,8 +400,6 @@ void cpu_RandomizedStrideConvolution_backward(
if (nActive and d_bias.numel())
at::sum_out(d_bias, d_output_features, {0}, false);
auto ip = weight.size(1);
auto op = weight.size(2);
for (Int i = 0; i < (Int)_rules.size(); i++) {
auto r = _rules[i];
int nRules = r.size() / 2;
......@@ -436,10 +413,8 @@ void cpu_RandomizedStrideConvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type());
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto d_output_rows = at::empty({nRules, op}, d_output_features.type());
rule_index_select<T>(d_output_rows, d_output_features, nRules, &r[1]);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
......
......@@ -34,8 +34,7 @@ double cpu_Deconvolution_updateOutput(
// auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 0), output_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type());
rule_index_select<T>(input_rows, input_features, nRules, &r[1]);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[1]);
auto w = weight.select(0, i);
auto output_rows = at::mm(input_rows, w);
rule_index_add_<T>(output_features, output_rows, nRules, &r[0]);
......@@ -62,8 +61,6 @@ void cpu_Deconvolution_backward(
if (nActive and d_bias.numel())
at::sum_out(d_bias, d_output_features, {0}, false);
auto ip = weight.size(1);
auto op = weight.size(2);
for (Int i = 0; i < (Int)_rules.size(); i++) {
auto r = _rules[i];
int nRules = r.size() / 2;
......@@ -77,10 +74,8 @@ void cpu_Deconvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 1), d_input_rows);
auto input_rows = at::empty({nRules, ip}, d_output_features.type());
rule_index_select<T>(input_rows, input_features, nRules, &r[1]);
auto d_output_rows = at::empty({nRules, op}, d_output_features.type());
rule_index_select<T>(d_output_rows, d_output_features, nRules, &r[0]);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[1]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[0]);
at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[1]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment