Commit 556a8c06 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

refactor index_select

parent ed2a1c04
...@@ -6,15 +6,16 @@ ...@@ -6,15 +6,16 @@
#include <cstring> #include <cstring>
template <typename T> template <typename T>
void rule_index_select(at::Tensor target, at::Tensor src, Int nRules, at::Tensor rule_index_select(at::Tensor src, Int nRules,
Int *rules) { Int *rules) {
auto n = src.size(1);
auto target = at::empty({nRules, n}, src.type());
auto t_ptr = target.data<T>(); auto t_ptr = target.data<T>();
auto s_ptr = src.data<T>(); auto s_ptr = src.data<T>();
auto n = target.size(1); #pragma omp parallel for
Int i; for (Int i = 0; i < nRules; ++i)
#pragma omp parallel for private(i)
for (i = 0; i < nRules; ++i)
std::memcpy(t_ptr + i * n, s_ptr + rules[2 * i] * n, sizeof(T) * n); std::memcpy(t_ptr + i * n, s_ptr + rules[2 * i] * n, sizeof(T) * n);
return target;
} }
template <typename T> template <typename T>
void rule_index_add_(at::Tensor target, at::Tensor src, Int nRules, void rule_index_add_(at::Tensor target, at::Tensor src, Int nRules,
...@@ -22,9 +23,8 @@ void rule_index_add_(at::Tensor target, at::Tensor src, Int nRules, ...@@ -22,9 +23,8 @@ void rule_index_add_(at::Tensor target, at::Tensor src, Int nRules,
auto t_ptr = target.data<T>(); auto t_ptr = target.data<T>();
auto s_ptr = src.data<T>(); auto s_ptr = src.data<T>();
auto n = target.size(1); auto n = target.size(1);
Int i; #pragma omp parallel for
#pragma omp parallel for private(i) for (Int i = 0; i < nRules; ++i) {
for (i = 0; i < nRules; ++i) {
auto t = t_ptr + rules[2 * i] * n; auto t = t_ptr + rules[2 * i] * n;
auto s = s_ptr + i * n; auto s = s_ptr + i * n;
for (int j = 0; j < n; ++j) for (int j = 0; j < n; ++j)
...@@ -62,8 +62,7 @@ double cpu_Convolution_updateOutput( ...@@ -62,8 +62,7 @@ double cpu_Convolution_updateOutput(
// auto w = weight.select(0, i); // auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w); // auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows); // output_features.index_add_(0, rt.select(1, 1), output_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type()); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto w = weight.select(0, i); auto w = weight.select(0, i);
auto output_rows = at::mm(input_rows, w); auto output_rows = at::mm(input_rows, w);
rule_index_add_<T>(output_features, output_rows, nRules, &r[1]); rule_index_add_<T>(output_features, output_rows, nRules, &r[1]);
...@@ -90,8 +89,6 @@ void cpu_Convolution_backward( ...@@ -90,8 +89,6 @@ void cpu_Convolution_backward(
if (nActive and d_bias.numel()) if (nActive and d_bias.numel())
at::sum_out(d_bias, d_output_features, {0}, false); at::sum_out(d_bias, d_output_features, {0}, false);
auto ip = weight.size(1);
auto op = weight.size(2);
for (Int i = 0; i < (Int)_rules.size(); i++) { for (Int i = 0; i < (Int)_rules.size(); i++) {
auto r = _rules[i]; auto r = _rules[i];
int nRules = r.size() / 2; int nRules = r.size() / 2;
...@@ -105,10 +102,8 @@ void cpu_Convolution_backward( ...@@ -105,10 +102,8 @@ void cpu_Convolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows); // at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t()); // auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows); // d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type()); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
rule_index_select<T>(input_rows, input_features, nRules, &r[0]); auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
auto d_output_rows = at::empty({nRules, op}, d_output_features.type());
rule_index_select<T>(d_output_rows, d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows); at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t()); auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]); rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
...@@ -144,8 +139,7 @@ double cpu_SubmanifoldConvolution_updateOutput( ...@@ -144,8 +139,7 @@ double cpu_SubmanifoldConvolution_updateOutput(
// auto w = weight.select(0, i); // auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w); // auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows); // output_features.index_add_(0, rt.select(1, 1), output_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type()); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto w = weight.select(0, i); auto w = weight.select(0, i);
auto output_rows = at::mm(input_rows, w); auto output_rows = at::mm(input_rows, w);
rule_index_add_<T>(output_features, output_rows, nRules, &r[1]); rule_index_add_<T>(output_features, output_rows, nRules, &r[1]);
...@@ -171,8 +165,6 @@ void cpu_SubmanifoldConvolution_backward( ...@@ -171,8 +165,6 @@ void cpu_SubmanifoldConvolution_backward(
if (nActive and d_bias.numel()) if (nActive and d_bias.numel())
at::sum_out(d_bias, d_output_features, {0}, false); at::sum_out(d_bias, d_output_features, {0}, false);
auto ip = weight.size(1);
auto op = weight.size(2);
for (Int i = 0; i < (Int)_rules.size(); i++) { for (Int i = 0; i < (Int)_rules.size(); i++) {
auto r = _rules[i]; auto r = _rules[i];
int nRules = r.size() / 2; int nRules = r.size() / 2;
...@@ -186,10 +178,8 @@ void cpu_SubmanifoldConvolution_backward( ...@@ -186,10 +178,8 @@ void cpu_SubmanifoldConvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows); // at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t()); // auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows); // d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type()); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
rule_index_select<T>(input_rows, input_features, nRules, &r[0]); auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
auto d_output_rows = at::empty({nRules, op}, d_output_features.type());
rule_index_select<T>(d_output_rows, d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows); at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t()); auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]); rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
...@@ -224,8 +214,7 @@ double cpu_PermutohedralSubmanifoldConvolution_updateOutput( ...@@ -224,8 +214,7 @@ double cpu_PermutohedralSubmanifoldConvolution_updateOutput(
// auto w = weight.select(0, i); // auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w); // auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows); // output_features.index_add_(0, rt.select(1, 1), output_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type()); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto w = weight.select(0, i); auto w = weight.select(0, i);
auto output_rows = at::mm(input_rows, w); auto output_rows = at::mm(input_rows, w);
rule_index_add_<T>(output_features, output_rows, nRules, &r[1]); rule_index_add_<T>(output_features, output_rows, nRules, &r[1]);
...@@ -250,8 +239,6 @@ void cpu_PermutohedralSubmanifoldConvolution_backward( ...@@ -250,8 +239,6 @@ void cpu_PermutohedralSubmanifoldConvolution_backward(
if (nActive and d_bias.numel()) if (nActive and d_bias.numel())
at::sum_out(d_bias, d_output_features, {0}, false); at::sum_out(d_bias, d_output_features, {0}, false);
auto ip = weight.size(1);
auto op = weight.size(2);
for (Int i = 0; i < (Int)_rules.size(); i++) { for (Int i = 0; i < (Int)_rules.size(); i++) {
auto r = _rules[i]; auto r = _rules[i];
int nRules = r.size() / 2; int nRules = r.size() / 2;
...@@ -265,10 +252,8 @@ void cpu_PermutohedralSubmanifoldConvolution_backward( ...@@ -265,10 +252,8 @@ void cpu_PermutohedralSubmanifoldConvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows); // at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t()); // auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows); // d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type()); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
rule_index_select<T>(input_rows, input_features, nRules, &r[0]); auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
auto d_output_rows = at::empty({nRules, op}, d_output_features.type());
rule_index_select<T>(d_output_rows, d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows); at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t()); auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]); rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
...@@ -307,8 +292,7 @@ double cpu_FullConvolution_updateOutput( ...@@ -307,8 +292,7 @@ double cpu_FullConvolution_updateOutput(
// auto w = weight.select(0, i); // auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w); // auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows); // output_features.index_add_(0, rt.select(1, 1), output_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type()); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto w = weight.select(0, i); auto w = weight.select(0, i);
auto output_rows = at::mm(input_rows, w); auto output_rows = at::mm(input_rows, w);
rule_index_add_<T>(output_features, output_rows, nRules, &r[1]); rule_index_add_<T>(output_features, output_rows, nRules, &r[1]);
...@@ -337,8 +321,6 @@ void cpu_FullConvolution_backward( ...@@ -337,8 +321,6 @@ void cpu_FullConvolution_backward(
if (nActive and d_bias.numel()) if (nActive and d_bias.numel())
at::sum_out(d_bias, d_output_features, {0}, false); at::sum_out(d_bias, d_output_features, {0}, false);
auto ip = weight.size(1);
auto op = weight.size(2);
for (Int i = 0; i < (Int)_rules.size(); i++) { for (Int i = 0; i < (Int)_rules.size(); i++) {
auto r = _rules[i]; auto r = _rules[i];
int nRules = r.size() / 2; int nRules = r.size() / 2;
...@@ -352,10 +334,8 @@ void cpu_FullConvolution_backward( ...@@ -352,10 +334,8 @@ void cpu_FullConvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows); // at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t()); // auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows); // d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type()); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
rule_index_select<T>(input_rows, input_features, nRules, &r[0]); auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
auto d_output_rows = at::empty({nRules, op}, d_output_features.type());
rule_index_select<T>(d_output_rows, d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows); at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t()); auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]); rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
...@@ -393,8 +373,7 @@ double cpu_RandomizedStrideConvolution_updateOutput( ...@@ -393,8 +373,7 @@ double cpu_RandomizedStrideConvolution_updateOutput(
// auto w = weight.select(0, i); // auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w); // auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows); // output_features.index_add_(0, rt.select(1, 1), output_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type()); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
rule_index_select<T>(input_rows, input_features, nRules, &r[0]);
auto w = weight.select(0, i); auto w = weight.select(0, i);
auto output_rows = at::mm(input_rows, w); auto output_rows = at::mm(input_rows, w);
rule_index_add_<T>(output_features, output_rows, nRules, &r[1]); rule_index_add_<T>(output_features, output_rows, nRules, &r[1]);
...@@ -421,8 +400,6 @@ void cpu_RandomizedStrideConvolution_backward( ...@@ -421,8 +400,6 @@ void cpu_RandomizedStrideConvolution_backward(
if (nActive and d_bias.numel()) if (nActive and d_bias.numel())
at::sum_out(d_bias, d_output_features, {0}, false); at::sum_out(d_bias, d_output_features, {0}, false);
auto ip = weight.size(1);
auto op = weight.size(2);
for (Int i = 0; i < (Int)_rules.size(); i++) { for (Int i = 0; i < (Int)_rules.size(); i++) {
auto r = _rules[i]; auto r = _rules[i];
int nRules = r.size() / 2; int nRules = r.size() / 2;
...@@ -436,10 +413,8 @@ void cpu_RandomizedStrideConvolution_backward( ...@@ -436,10 +413,8 @@ void cpu_RandomizedStrideConvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows); // at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t()); // auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows); // d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type()); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
rule_index_select<T>(input_rows, input_features, nRules, &r[0]); auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
auto d_output_rows = at::empty({nRules, op}, d_output_features.type());
rule_index_select<T>(d_output_rows, d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows); at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t()); auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]); rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
......
...@@ -34,8 +34,7 @@ double cpu_Deconvolution_updateOutput( ...@@ -34,8 +34,7 @@ double cpu_Deconvolution_updateOutput(
// auto w = weight.select(0, i); // auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w); // auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 0), output_rows); // output_features.index_add_(0, rt.select(1, 0), output_rows);
auto input_rows = at::empty({nRules, ip}, input_features.type()); auto input_rows = rule_index_select<T>(input_features, nRules, &r[1]);
rule_index_select<T>(input_rows, input_features, nRules, &r[1]);
auto w = weight.select(0, i); auto w = weight.select(0, i);
auto output_rows = at::mm(input_rows, w); auto output_rows = at::mm(input_rows, w);
rule_index_add_<T>(output_features, output_rows, nRules, &r[0]); rule_index_add_<T>(output_features, output_rows, nRules, &r[0]);
...@@ -62,8 +61,6 @@ void cpu_Deconvolution_backward( ...@@ -62,8 +61,6 @@ void cpu_Deconvolution_backward(
if (nActive and d_bias.numel()) if (nActive and d_bias.numel())
at::sum_out(d_bias, d_output_features, {0}, false); at::sum_out(d_bias, d_output_features, {0}, false);
auto ip = weight.size(1);
auto op = weight.size(2);
for (Int i = 0; i < (Int)_rules.size(); i++) { for (Int i = 0; i < (Int)_rules.size(); i++) {
auto r = _rules[i]; auto r = _rules[i];
int nRules = r.size() / 2; int nRules = r.size() / 2;
...@@ -77,10 +74,8 @@ void cpu_Deconvolution_backward( ...@@ -77,10 +74,8 @@ void cpu_Deconvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows); // at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t()); // auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 1), d_input_rows); // d_input_features.index_add_(0, rt.select(1, 1), d_input_rows);
auto input_rows = at::empty({nRules, ip}, d_output_features.type()); auto input_rows = rule_index_select<T>(input_features, nRules, &r[1]);
rule_index_select<T>(input_rows, input_features, nRules, &r[1]); auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[0]);
auto d_output_rows = at::empty({nRules, op}, d_output_features.type());
rule_index_select<T>(d_output_rows, d_output_features, nRules, &r[0]);
at::mm_out(dw, input_rows.t(), d_output_rows); at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t()); auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[1]); rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[1]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment