Commit b92c0456 authored by Daniel Povey

Fix some compilation errors

parent f29fdd23
@@ -25,9 +25,9 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
   TORCH_CHECK(pos_add.device() == input.device() &&
               pos_mul.device() == pos_add.device(),
               "Input devices mismatch");
-  dtype scalar_t = input.dtype();
-  TORCH_CHECK(pos_add.dtype() == scalar_t &&
-              pos_mul.dtype() == scalar_t,
+  auto scalar_t = input.scalar_type();
+  TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
+              pos_mul.scalar_type() == scalar_t,
               "Input dtypes mismatch");
   torch::Tensor output = torch::empty({N, C, H, W},
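This hunk is the same dtype fix that reappears later in the CUDA file: there is no type named `dtype` in the ATen API, and `Tensor::dtype()` returns a `caffe2::TypeMeta` rather than an enum, so the old line could not compile. `Tensor::scalar_type()` returns a `c10::ScalarType`, which `auto` can hold and `==` can compare. A minimal sketch of the corrected pattern (the helper name is illustrative, not from the commit):

    #include <torch/extension.h>

    // Illustrative helper, not part of the commit: enforce a shared dtype
    // the way the fixed TORCH_CHECK does.
    void check_same_dtype(const torch::Tensor &a, const torch::Tensor &b) {
      auto st = a.scalar_type();  // c10::ScalarType, e.g. torch::kFloat
      TORCH_CHECK(b.scalar_type() == st, "Input dtypes mismatch");
    }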
@@ -35,9 +35,10 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
   AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_cpu_loop", ([&] {
       auto input_a = input.accessor<scalar_t, 4>(),
-          pos_add_a = pos_add.accessor<scalar_t, 3>(),
-          pos_mul_a = pos_add.accessor<scalar_t, 3>(),
-          output_a = pos_add.accessor<scalar_t, 4>();
+          output_a = output.accessor<scalar_t, 4>();
+      auto pos_add_a = pos_add.accessor<scalar_t, 3>(),
+          pos_mul_a = pos_mul.accessor<scalar_t, 3>();
       for (int n = 0; n < N; n++) {
         for (int c = 0; c < C; c++) {
           auto src_input_a = input_a[n][c],
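The accessor fix addresses two problems at once: every declarator in a single `auto` statement must deduce to the same type, and `TensorAccessor<scalar_t, 4>` and `TensorAccessor<scalar_t, 3>` are different types, so the 4-D and 3-D accessors have to be declared separately; the old code also built `output_a` and `pos_mul_a` from `pos_add` instead of the tensors they are meant to view. A sketch of the deduction rule, with illustrative names:

    // Sketch, not from the commit: why the declaration had to be split.
    void accessor_demo(torch::Tensor input, torch::Tensor output,
                       torch::Tensor pos_add, torch::Tensor pos_mul) {
      auto input_a  = input.accessor<float, 4>(),
           output_a = output.accessor<float, 4>();    // OK: both 4-D
      auto pos_add_a = pos_add.accessor<float, 3>(),
           pos_mul_a = pos_mul.accessor<float, 3>();  // OK: both 3-D
      // auto bad_a = input.accessor<float, 4>(),
      //      bad_b = pos_add.accessor<float, 3>();   // error: deduced types differ
    }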
@@ -56,12 +57,12 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
               if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
                   static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
                 src = src_input_a[src_h][src_w];
-              scalar_t relu = src + dest + this_pos_add_a;
+              scalar_t relu = src + dest + this_pos_add_a[kh][kw];
               if (relu > 0.0)
-                sum += relu * this_pos_mul_a;
+                sum += relu * this_pos_mul_a[kh][kw];
             }
           }
-          output_a[h][w] = sum;
+          this_output_a[h][w] = sum;
         }
       }
     }
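The remaining fixes in this function index the positional parameters at the current kernel offset and write through what is presumably a per-(n, c) view defined in the elided lines; accessors are not scalars, so the old lines could not compile. Per kernel position the loop adds the source pixel, the destination pixel, and a positional offset, gates the result with a ReLU, and accumulates it scaled by a positional weight. A scalar sketch of one contribution (assumed from the diff, not code from the commit):

    // One (kh, kw) term of the output sum, written out as a scalar helper.
    template <typename scalar_t>
    scalar_t point_contribution(scalar_t src, scalar_t dest,
                                scalar_t pos_add_khkw, scalar_t pos_mul_khkw) {
      scalar_t relu = src + dest + pos_add_khkw;  // shifted pre-activation
      return relu > 0 ? relu * pos_mul_khkw       // ReLU gate, then weight
                      : scalar_t(0);
    }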
@@ -81,64 +82,9 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
-template <typename T_accessor, typename scalar_t>
-inline
-void discounted_sum_update(T_accessor &accessor, int batchsz, scalar_t gamma, int change_pos, int discounted_pos) {
-  for (int i=0; i<batchsz-3; i+=4) {
-    accessor[i+0][change_pos] += gamma * accessor[i+0][discounted_pos];
-    accessor[i+1][change_pos] += gamma * accessor[i+1][discounted_pos];
-    accessor[i+2][change_pos] += gamma * accessor[i+2][discounted_pos];
-    accessor[i+3][change_pos] += gamma * accessor[i+3][discounted_pos];
-  }
-  for (int i=(batchsz - (batchsz & 3)); i<batchsz; i++) {
-    accessor[i][change_pos] += gamma * accessor[i][discounted_pos];
-  }
-}
-torch::Tensor discounted_cumsum_left_cpu(torch::Tensor x, double gamma) {
-  TORCH_CHECK(x.device().is_cpu(), "Input must be a CPU tensor");
-  TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
-  TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
-  auto y = x.clone();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_left_cpu_loop", ([&] {
-    auto ya = y.accessor<scalar_t, 2>();
-    for (int j=0; j<y.size(1); j++) {
-      int j_left = j-1;
-      if (j_left == -1) {
-        continue;
-      }
-      discounted_sum_update(ya, y.size(0), gamma, j, j_left);
-    }
-  }));
-  return y;
-}
-torch::Tensor discounted_cumsum_right_cpu(torch::Tensor x, double gamma) {
-  TORCH_CHECK(x.device().is_cpu(), "Input must be a CPU tensor");
-  TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
-  TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
-  auto y = x.clone();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_right_cpu_loop", ([&] {
-    auto ya = y.accessor<scalar_t, 2>();
-    for (int j=y.size(1)-1; j>=0; j--) {
-      int j_right = j+1;
-      if (j_right == y.size(1)) {
-        continue;
-      }
-      discounted_sum_update(ya, y.size(0), gamma, j, j_right);
-    }
-  }));
-  return y;
-}
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("integrated_conv_cpu", &integrated_conv_cpu, "Integrated convolution forward function (CPU)");
-  m.def("integrated_conv_backward_cpu", &integrated_conv_forward_cpu, "Integrated convolution backward function (CPU)");
+  m.def("integrated_conv_backward_cpu", &integrated_conv_backward_cpu, "Integrated convolution backward function (CPU)");
 }
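The 64 deleted lines above are a self-contained discounted-cumsum implementation that was never bound in the module below, which is consistent with the commit message if it was leftover code. For the record, the deleted helper adds a discounted neighboring column into each column, unrolled by four over the batch dimension (`batchsz & 3` equals `batchsz % 4` for non-negative sizes), and the two scan functions apply the recurrence y[:, j] = x[:, j] + gamma * y[:, j-1] from the left and its mirror from the right. An un-unrolled sketch of the helper, for clarity only:

    // Equivalent simple form of the deleted discounted_sum_update:
    // column change_pos accumulates gamma times column discounted_pos,
    // row by row; the original unrolled this loop by 4 for speed.
    template <typename T_accessor, typename scalar_t>
    void discounted_sum_update_simple(T_accessor &a, int batchsz, scalar_t gamma,
                                      int change_pos, int discounted_pos) {
      for (int i = 0; i < batchsz; i++)
        a[i][change_pos] += gamma * a[i][discounted_pos];
    }

Worked example of the left scan: with gamma = 0.5 a row x = [1, 1, 1] becomes y = [1, 1.5, 1.75], since y[j] = x[j] + 0.5 * y[j-1].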
@@ -17,5 +17,5 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("integrated_conv_cuda", &integrated_conv_cuda, "Integrated convolution forward function (CUDA)");
-  m.def("integrated_conv_backward_cuda", &integrated_conv_forward_cuda, "Integrated convolution backward function (CUDA)");
+  m.def("integrated_conv_backward_cuda", &integrated_conv_backward_cuda, "Integrated convolution backward function (CUDA)");
 }
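Both pybind fixes repair the same copy-paste error: the forward functions are named `integrated_conv_cpu` and `integrated_conv_cuda`, so the old bindings referenced `integrated_conv_forward_cpu` and `integrated_conv_forward_cuda`, identifiers that were never declared, and the extensions failed to compile. The backward entry points now bind the actual backward functions.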
@@ -241,9 +241,9 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
   TORCH_CHECK(pos_add.device() == input.device() &&
               pos_mul.device() == pos_add.device(),
               "Input devices mismatch");
-  dtype scalar_t = input.dtype();
-  TORCH_CHECK(pos_add.dtype() == scalar_t &&
-              pos_mul.dtype() == scalar_t,
+  auto scalar_t = input.scalar_type();
+  TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
+              pos_mul.scalar_type() == scalar_t,
               "Input dtypes mismatch");
   torch::Tensor output = torch::empty({N, C, H, W},
...