Commit 30ab3e6b authored by rusty1s

cpu done

parent 8fa47ae8
@@ -76,7 +76,7 @@ spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
auto weight_index_data = weight_index.data_ptr<int64_t>();
- AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_forward", [&] {
+ AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] {
auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
@@ -137,7 +137,7 @@ torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
auto kernel_size_data = kernel_size.data_ptr<int64_t>();
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
- AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_backward", [&] {
+ AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] {
auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>();
...
// Compatibility shim: PyTorch >= 1.3 renamed Tensor::data<T>() to
// Tensor::data_ptr<T>(), so pick the accessor that matches the installed version.
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/extension.h>
#include "compat.h"
at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
at::Tensor weight_index) {
auto E = x.size(0), M_in = x.size(1), M_out = weight.size(2);
auto S = basis.size(1);
auto out = at::empty({E, M_out}, x.options());
AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "weighting_fw", [&] {
auto x_data = x.DATA_PTR<scalar_t>();
auto weight_data = weight.DATA_PTR<scalar_t>();
auto basis_data = basis.DATA_PTR<scalar_t>();
auto weight_index_data = weight_index.DATA_PTR<int64_t>();
auto out_data = out.DATA_PTR<scalar_t>();
scalar_t v;
for (ptrdiff_t e = 0; e < E; e++) {
for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
v = 0;
for (ptrdiff_t s = 0; s < S; s++) {
auto b = basis_data[e * S + s];
auto wi = weight_index_data[e * S + s];
for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
auto tmp =
weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
m_out * weight.stride(2)];
tmp *= b * x_data[e * x.stride(0) + m_in * x.stride(1)];
v += tmp;
}
}
out_data[e * M_out + m_out] = v;
}
}
});
return out;
}
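Reading off the loop nest above, weighting_fw accumulates, for every edge e and output channel m_out (a summary sketch with B = basis, W = weight and wi(e, s) = weight_index; the notation is mine, not part of the commit):

\[
\mathrm{out}[e, m_{\mathrm{out}}] \;=\; \sum_{s=0}^{S-1} \sum_{m_{\mathrm{in}}=0}^{M_{\mathrm{in}}-1} B[e, s] \cdot x[e, m_{\mathrm{in}}] \cdot W\bigl[wi(e, s),\, m_{\mathrm{in}},\, m_{\mathrm{out}}\bigr]
\]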
at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
at::Tensor basis, at::Tensor weight_index) {
auto E = grad_out.size(0), M_in = weight.size(1), M_out = grad_out.size(1);
auto S = basis.size(1);
auto grad_x = at::zeros({E, M_in}, grad_out.options());
AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
auto weight_data = weight.DATA_PTR<scalar_t>();
auto basis_data = basis.DATA_PTR<scalar_t>();
auto weight_index_data = weight_index.DATA_PTR<int64_t>();
auto grad_x_data = grad_x.DATA_PTR<scalar_t>();
for (ptrdiff_t e = 0; e < E; e++) {
for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
auto g =
grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
for (ptrdiff_t s = 0; s < S; s++) {
auto b = basis_data[e * S + s];
auto wi = weight_index_data[e * S + s];
for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
auto w =
weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
m_out * weight.stride(2)];
grad_x_data[e * M_in + m_in] += g * b * w;
}
}
}
}
});
return grad_x;
}
at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
at::Tensor weight_index, int64_t K) {
auto E = grad_out.size(0), M_in = x.size(1), M_out = grad_out.size(1);
auto S = basis.size(1);
auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.options());
AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_w", [&] {
auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
auto x_data = x.DATA_PTR<scalar_t>();
auto basis_data = basis.DATA_PTR<scalar_t>();
auto weight_index_data = weight_index.DATA_PTR<int64_t>();
auto grad_weight_data = grad_weight.DATA_PTR<scalar_t>();
for (ptrdiff_t e = 0; e < E; e++) {
for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
auto g =
grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
for (ptrdiff_t s = 0; s < S; s++) {
auto b = basis_data[e * S + s];
auto wi = weight_index_data[e * S + s];
for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
auto v = g * b * x_data[e * x.stride(0) + m_in * x.stride(1)];
grad_weight_data[wi * M_in * M_out + m_in * M_out + m_out] += v;
}
}
}
}
});
return grad_weight;
}
at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
at::Tensor weight_index) {
auto E = grad_out.size(0), M_in = x.size(1), M_out = grad_out.size(1);
auto S = weight_index.size(1);
auto grad_basis = at::zeros({E, S}, grad_out.options());
AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_b", [&] {
auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
auto x_data = x.DATA_PTR<scalar_t>();
auto weight_data = weight.DATA_PTR<scalar_t>();
auto weight_index_data = weight_index.DATA_PTR<int64_t>();
auto grad_basis_data = grad_basis.DATA_PTR<scalar_t>();
for (ptrdiff_t e = 0; e < E; e++) {
for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
auto g =
grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
for (ptrdiff_t s = 0; s < S; s++) {
scalar_t b = 0;
auto wi = weight_index_data[e * S + s];
for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
auto w =
weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
m_out * weight.stride(2)];
w *= x_data[e * x.stride(0) + m_in * x.stride(1)];
b += w;
}
grad_basis_data[e * S + s] += g * b;
}
}
}
});
return grad_basis;
}
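The three backward kernels above are the matching partial derivatives of that sum. Writing g = grad_out, they accumulate (again only a reading of the loops, stated here for reference):

\[
\frac{\partial L}{\partial x[e, m_{\mathrm{in}}]} = \sum_{m_{\mathrm{out}}} \sum_{s} g[e, m_{\mathrm{out}}]\, B[e, s]\, W\bigl[wi(e, s), m_{\mathrm{in}}, m_{\mathrm{out}}\bigr],
\]
\[
\frac{\partial L}{\partial W[k, m_{\mathrm{in}}, m_{\mathrm{out}}]} = \sum_{e} \sum_{s\,:\, wi(e, s) = k} g[e, m_{\mathrm{out}}]\, B[e, s]\, x[e, m_{\mathrm{in}}],
\]
\[
\frac{\partial L}{\partial B[e, s]} = \sum_{m_{\mathrm{out}}} g[e, m_{\mathrm{out}}] \sum_{m_{\mathrm{in}}} W\bigl[wi(e, s), m_{\mathrm{in}}, m_{\mathrm{out}}\bigr]\, x[e, m_{\mathrm{in}}].
\]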
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("weighting_fw", &weighting_fw, "Weighting Forward (CPU)");
m.def("weighting_bw_x", &weighting_bw_x, "Weighting Backward X (CPU)");
m.def("weighting_bw_w", &weighting_bw_w, "Weighting Backward Weight (CPU)");
m.def("weighting_bw_b", &weighting_bw_b, "Weighting Backward Basis (CPU)");
}
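For context, a minimal way to exercise the forward kernel from C++ could look like the sketch below. It is a hypothetical standalone check, not part of this commit: the main function, the toy shapes, and the assumption that weighting_fw is declared in a header visible to this file are all illustrative.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Invented toy sizes: E edges, M_in/M_out channels, S basis terms, K kernel entries.
  int64_t E = 4, M_in = 3, M_out = 5, S = 2, K = 7;
  auto x = torch::rand({E, M_in});
  auto weight = torch::rand({K, M_in, M_out});
  auto basis = torch::rand({E, S});
  auto weight_index = torch::randint(0, K, {E, S}, torch::kInt64);

  // Assumes weighting_fw from the extension above is linked into this binary.
  auto out = weighting_fw(x, weight, basis, weight_index);
  std::cout << out.sizes() << std::endl;  // expected: [4, 5]
  return 0;
}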
@@ -5,14 +5,97 @@
torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
- return x;
+ CHECK_CPU(x);
CHECK_CPU(weight);
CHECK_CPU(basis);
CHECK_CPU(weight_index);
CHECK_INPUT(x.size(1) == weight.size(1));
auto E = x.size(0);
auto M_in = x.size(1);
auto M_out = weight.size(2);
auto S = basis.size(1);
auto out = at::empty({E, M_out}, x.options());
auto weight_index_data = weight_index.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
scalar_t v;
for (int64_t e = 0; e < E; e++) {
for (int64_t m_out = 0; m_out < M_out; m_out++) {
v = 0;
for (int64_t s = 0; s < S; s++) {
auto b = basis_data[e * S + s];
auto wi = weight_index_data[e * S + s];
for (int64_t m_in = 0; m_in < M_in; m_in++) {
auto tmp =
weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
m_out * weight.stride(2)];
tmp *= b * x_data[e * x.stride(0) + m_in * x.stride(1)];
v += tmp;
}
}
out_data[e * M_out + m_out] = v;
}
}
});
return out;
}
torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
- return grad_out;
+ CHECK_CPU(grad_out);
CHECK_CPU(weight);
CHECK_CPU(basis);
CHECK_CPU(weight_index);
CHECK_INPUT(grad_out.size(1) == weight.size(2));
auto E = grad_out.size(0);
auto M_in = weight.size(1);
auto M_out = grad_out.size(1);
auto S = basis.size(1);
auto grad_x = at::zeros({E, M_in}, grad_out.options());
auto weight_index_data = weight_index.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
auto grad_x_data = grad_x.data_ptr<scalar_t>();
for (int64_t e = 0; e < E; e++) {
for (int64_t m_out = 0; m_out < M_out; m_out++) {
auto g =
grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
for (int64_t s = 0; s < S; s++) {
auto b = basis_data[e * S + s];
auto wi = weight_index_data[e * S + s];
for (int64_t m_in = 0; m_in < M_in; m_in++) {
auto w =
weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
m_out * weight.stride(2)];
grad_x_data[e * M_in + m_in] += g * b * w;
}
}
}
}
});
return grad_x;
}
torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
@@ -20,12 +103,91 @@ torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size) {
- return grad_out;
+ CHECK_CPU(grad_out);
CHECK_CPU(x);
CHECK_CPU(basis);
CHECK_CPU(weight_index);
auto E = grad_out.size(0);
auto M_in = x.size(1);
auto M_out = grad_out.size(1);
auto S = basis.size(1);
auto grad_weight = at::zeros({kernel_size, M_in, M_out}, grad_out.options());
auto weight_index_data = weight_index.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
auto grad_weight_data = grad_weight.data_ptr<scalar_t>();
for (int64_t e = 0; e < E; e++) {
for (int64_t m_out = 0; m_out < M_out; m_out++) {
auto g =
grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
for (int64_t s = 0; s < S; s++) {
auto b = basis_data[e * S + s];
auto wi = weight_index_data[e * S + s];
for (int64_t m_in = 0; m_in < M_in; m_in++) {
auto v = g * b * x_data[e * x.stride(0) + m_in * x.stride(1)];
grad_weight_data[wi * M_in * M_out + m_in * M_out + m_out] += v;
}
}
}
}
});
return grad_weight;
}
torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index) {
- return grad_out;
+ CHECK_CPU(grad_out);
CHECK_CPU(x);
CHECK_CPU(weight);
CHECK_CPU(weight_index);
CHECK_INPUT(x.size(1) == weight.size(1));
CHECK_INPUT(grad_out.size(1) == weight.size(2));
auto E = grad_out.size(0);
auto M_in = x.size(1);
auto M_out = grad_out.size(1);
auto S = weight_index.size(1);
auto grad_basis = at::zeros({E, S}, grad_out.options());
auto weight_index_data = weight_index.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
for (int64_t e = 0; e < E; e++) {
for (int64_t m_out = 0; m_out < M_out; m_out++) {
auto g =
grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
for (int64_t s = 0; s < S; s++) {
scalar_t b = 0;
auto wi = weight_index_data[e * S + s];
for (int64_t m_in = 0; m_in < M_in; m_in++) {
auto w =
weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
m_out * weight.stride(2)];
w *= x_data[e * x.stride(0) + m_in * x.stride(1)];
b += w;
}
grad_basis_data[e * S + s] += g * b;
}
}
}
});
return grad_basis;
}
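A quick way to sanity-check the new CPU kernels is a finite-difference comparison between the forward and backward-x paths. The snippet below is a rough sketch, not part of the commit: it assumes both functions are declared in a header available to the test, and the harness, sizes, and the single probed entry are invented for illustration.

#include <torch/torch.h>
#include <iostream>

int main() {
  int64_t E = 2, M_in = 3, M_out = 4, S = 2, K = 5;
  auto opts = torch::TensorOptions().dtype(torch::kDouble);
  auto x = torch::rand({E, M_in}, opts);
  auto weight = torch::rand({K, M_in, M_out}, opts);
  auto basis = torch::rand({E, S}, opts);
  auto weight_index = torch::randint(0, K, {E, S}, torch::kInt64);
  auto grad_out = torch::ones({E, M_out}, opts);  // gradient of sum(out)

  // Analytic gradient of sum(out) w.r.t. x via the backward kernel.
  auto grad_x = spline_weighting_bw_x_cpu(grad_out, weight, basis, weight_index);

  // Central finite difference for the single entry x[0][0].
  double eps = 1e-6;
  auto xp = x.clone(); xp[0][0] += eps;
  auto xm = x.clone(); xm[0][0] -= eps;
  auto fd = (spline_weighting_fw_cpu(xp, weight, basis, weight_index).sum() -
             spline_weighting_fw_cpu(xm, weight, basis, weight_index).sum()) /
            (2 * eps);

  std::cout << grad_x[0][0].item<double>() << " vs " << fd.item<double>() << std::endl;
  return 0;
}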