Commit 8422a6f5 authored by Benjamin Thomas Graham
Browse files

aten api changes

parent 32f8cbcf
...@@ -6,13 +6,12 @@ ...@@ -6,13 +6,12 @@
#include <cstring> #include <cstring>
template <typename T> template <typename T>
at::Tensor rule_index_select(at::Tensor src, Int nRules, at::Tensor rule_index_select(at::Tensor src, Int nRules, Int *rules) {
Int *rules) {
auto n = src.size(1); auto n = src.size(1);
auto target = at::empty({nRules, n}, src.type()); auto target = at::empty({nRules, n}, src.options());
auto t_ptr = target.data<T>(); auto t_ptr = target.data<T>();
auto s_ptr = src.data<T>(); auto s_ptr = src.data<T>();
#pragma omp parallel for #pragma omp parallel for
for (Int i = 0; i < nRules; ++i) for (Int i = 0; i < nRules; ++i)
std::memcpy(t_ptr + i * n, s_ptr + rules[2 * i] * n, sizeof(T) * n); std::memcpy(t_ptr + i * n, s_ptr + rules[2 * i] * n, sizeof(T) * n);
return target; return target;
...@@ -23,7 +22,7 @@ void rule_index_add_(at::Tensor target, at::Tensor src, Int nRules, ...@@ -23,7 +22,7 @@ void rule_index_add_(at::Tensor target, at::Tensor src, Int nRules,
auto t_ptr = target.data<T>(); auto t_ptr = target.data<T>();
auto s_ptr = src.data<T>(); auto s_ptr = src.data<T>();
auto n = target.size(1); auto n = target.size(1);
#pragma omp parallel for #pragma omp parallel for
for (Int i = 0; i < nRules; ++i) { for (Int i = 0; i < nRules; ++i) {
auto t = t_ptr + rules[2 * i] * n; auto t = t_ptr + rules[2 * i] * n;
auto s = s_ptr + i * n; auto s = s_ptr + i * n;
...@@ -103,7 +102,8 @@ void cpu_Convolution_backward( ...@@ -103,7 +102,8 @@ void cpu_Convolution_backward(
// auto d_input_rows = at::mm(d_output_rows, w.t()); // auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows); // d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]); auto d_output_rows =
rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows); at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t()); auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]); rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
...@@ -179,7 +179,8 @@ void cpu_SubmanifoldConvolution_backward( ...@@ -179,7 +179,8 @@ void cpu_SubmanifoldConvolution_backward(
// auto d_input_rows = at::mm(d_output_rows, w.t()); // auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows); // d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]); auto d_output_rows =
rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows); at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t()); auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]); rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
...@@ -253,7 +254,8 @@ void cpu_PermutohedralSubmanifoldConvolution_backward( ...@@ -253,7 +254,8 @@ void cpu_PermutohedralSubmanifoldConvolution_backward(
// auto d_input_rows = at::mm(d_output_rows, w.t()); // auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows); // d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]); auto d_output_rows =
rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows); at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t()); auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]); rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
...@@ -335,7 +337,8 @@ void cpu_FullConvolution_backward( ...@@ -335,7 +337,8 @@ void cpu_FullConvolution_backward(
// auto d_input_rows = at::mm(d_output_rows, w.t()); // auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows); // d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]); auto d_output_rows =
rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows); at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t()); auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]); rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
...@@ -414,7 +417,8 @@ void cpu_RandomizedStrideConvolution_backward( ...@@ -414,7 +417,8 @@ void cpu_RandomizedStrideConvolution_backward(
// auto d_input_rows = at::mm(d_output_rows, w.t()); // auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows); // d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]); auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]); auto d_output_rows =
rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows); at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t()); auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]); rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment