Commit 8422a6f5 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

aten api changes

parent 32f8cbcf
......@@ -6,13 +6,12 @@
#include <cstring>
template <typename T>
at::Tensor rule_index_select(at::Tensor src, Int nRules,
Int *rules) {
at::Tensor rule_index_select(at::Tensor src, Int nRules, Int *rules) {
auto n = src.size(1);
auto target = at::empty({nRules, n}, src.type());
auto target = at::empty({nRules, n}, src.options());
auto t_ptr = target.data<T>();
auto s_ptr = src.data<T>();
#pragma omp parallel for
#pragma omp parallel for
for (Int i = 0; i < nRules; ++i)
std::memcpy(t_ptr + i * n, s_ptr + rules[2 * i] * n, sizeof(T) * n);
return target;
......@@ -23,7 +22,7 @@ void rule_index_add_(at::Tensor target, at::Tensor src, Int nRules,
auto t_ptr = target.data<T>();
auto s_ptr = src.data<T>();
auto n = target.size(1);
#pragma omp parallel for
#pragma omp parallel for
for (Int i = 0; i < nRules; ++i) {
auto t = t_ptr + rules[2 * i] * n;
auto s = s_ptr + i * n;
......@@ -103,7 +102,8 @@ void cpu_Convolution_backward(
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
auto d_output_rows =
rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
......@@ -179,7 +179,8 @@ void cpu_SubmanifoldConvolution_backward(
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
auto d_output_rows =
rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
......@@ -253,7 +254,8 @@ void cpu_PermutohedralSubmanifoldConvolution_backward(
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
auto d_output_rows =
rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
......@@ -335,7 +337,8 @@ void cpu_FullConvolution_backward(
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
auto d_output_rows =
rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
......@@ -414,7 +417,8 @@ void cpu_RandomizedStrideConvolution_backward(
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto input_rows = rule_index_select<T>(input_features, nRules, &r[0]);
auto d_output_rows = rule_index_select<T>(d_output_features, nRules, &r[1]);
auto d_output_rows =
rule_index_select<T>(d_output_features, nRules, &r[1]);
at::mm_out(dw, input_rows.t(), d_output_rows);
auto d_input_rows = at::mm(d_output_rows, w.t());
rule_index_add_<T>(d_input_features, d_input_rows, nRules, &r[0]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment