"examples/pytorch/vscode:/vscode.git/clone" did not exist on "dca2580b8494f2358ef2634bd8d16ffd213fe30f"
Commit b92c0456 authored by Daniel Povey's avatar Daniel Povey
Browse files

Fix some compilation errors
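
Replaces the non-compiling dtype-based type checks with scalar_type(), indexes
the pos_add/pos_mul accessors with [kh][kw] in the CPU inner loop, writes the
result through this_output_a, registers the backward functions with the correct
function pointers in PYBIND11_MODULE, and drops the leftover discounted-cumsum
CPU code.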

parent f29fdd23
@@ -25,9 +25,9 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
   TORCH_CHECK(pos_add.device() == input.device() &&
               pos_mul.device() == pos_add.device(),
               "Input devices mismatch");
-  dtype scalar_t = input.dtype();
-  TORCH_CHECK(pos_add.dtype() == scalar_t &&
-              pos_mul.dtype() == scalar_t,
+  auto scalar_t = input.scalar_type();
+  TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
+              pos_mul.scalar_type() == scalar_t,
               "Input dtypes mismatch");
   torch::Tensor output = torch::empty({N, C, H, W},
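
Context for the hunk above: torch::Tensor::dtype() returns a caffe2::TypeMeta,
and `dtype` is not a type name you can declare a local with, so the removed
lines did not compile; scalar_type() returns the c10::ScalarType enum, which
can be held in an `auto` local and compared directly. A minimal sketch of the
corrected check, using a hypothetical helper name (the real code inlines the
check inside integrated_conv_cpu):

    #include <torch/extension.h>

    // Hypothetical standalone helper -- in the commit this check is inlined
    // in integrated_conv_cpu and integrated_conv_cuda.
    void check_same_dtype(torch::Tensor input,
                          torch::Tensor pos_add,
                          torch::Tensor pos_mul) {
      auto scalar_t = input.scalar_type();  // c10::ScalarType enum value
      TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
                  pos_mul.scalar_type() == scalar_t,
                  "Input dtypes mismatch");
    }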
@@ -35,9 +35,10 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
   AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_cpu_loop", ([&] {
     auto input_a = input.accessor<scalar_t, 4>(),
-         pos_add_a = pos_add.accessor<scalar_t, 3>(),
-         pos_mul_a = pos_add.accessor<scalar_t, 3>(),
-         output_a = pos_add.accessor<scalar_t, 4>();
+         output_a = output.accessor<scalar_t, 4>();
+    auto pos_add_a = pos_add.accessor<scalar_t, 3>(),
+         pos_mul_a = pos_mul.accessor<scalar_t, 3>();
     for (int n = 0; n < N; n++) {
       for (int c = 0; c < C; c++) {
         auto src_input_a = input_a[n][c],
@@ -56,12 +57,12 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
             if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
                 static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
               src = src_input_a[src_h][src_w];
-            scalar_t relu = src + dest + this_pos_add_a;
+            scalar_t relu = src + dest + this_pos_add_a[kh][kw];
             if (relu > 0.0)
-              sum += relu * this_pos_mul_a;
+              sum += relu * this_pos_mul_a[kh][kw];
           }
         }
-        output_a[h][w] = sum;
+        this_output_a[h][w] = sum;
       }
     }
   }
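
The accessor fixes above follow the usual TensorAccessor pattern:
accessor<T, N>() returns an N-dimensional accessor, and each operator[] peels
off one dimension, so a 3-D accessor must be indexed as a[kh][kw] (after
slicing the channel) before it yields a scalar. The removed lines used the
accessor object itself in arithmetic, which does not compile. A minimal
self-contained sketch of the pattern (names, shapes, and float32 dtype are
assumptions, not from the source):

    #include <torch/extension.h>

    // Assumed shapes: pos_add and pos_mul are (C, kH, kW), float32.
    float sum_positive_products(torch::Tensor pos_add,
                                torch::Tensor pos_mul, int c) {
      auto add_a = pos_add.accessor<float, 3>();  // 3-D accessor
      auto mul_a = pos_mul.accessor<float, 3>();
      auto this_add_a = add_a[c],                 // 2-D sub-accessors
           this_mul_a = mul_a[c];
      float sum = 0.0f;
      for (int kh = 0; kh < pos_add.size(1); kh++) {
        for (int kw = 0; kw < pos_add.size(2); kw++) {
          float relu = this_add_a[kh][kw];        // scalar element
          if (relu > 0.0f)
            sum += relu * this_mul_a[kh][kw];
        }
      }
      return sum;
    }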
@@ -81,64 +82,9 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
-template <typename T_accessor, typename scalar_t>
-inline
-void discounted_sum_update(T_accessor &accessor, int batchsz, scalar_t gamma, int change_pos, int discounted_pos) {
-  for (int i=0; i<batchsz-3; i+=4) {
-    accessor[i+0][change_pos] += gamma * accessor[i+0][discounted_pos];
-    accessor[i+1][change_pos] += gamma * accessor[i+1][discounted_pos];
-    accessor[i+2][change_pos] += gamma * accessor[i+2][discounted_pos];
-    accessor[i+3][change_pos] += gamma * accessor[i+3][discounted_pos];
-  }
-  for (int i=(batchsz - (batchsz & 3)); i<batchsz; i++) {
-    accessor[i][change_pos] += gamma * accessor[i][discounted_pos];
-  }
-}
-
-torch::Tensor discounted_cumsum_left_cpu(torch::Tensor x, double gamma) {
-  TORCH_CHECK(x.device().is_cpu(), "Input must be a CPU tensor");
-  TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
-  TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
-  auto y = x.clone();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_left_cpu_loop", ([&] {
-    auto ya = y.accessor<scalar_t, 2>();
-    for (int j=0; j<y.size(1); j++) {
-      int j_left = j-1;
-      if (j_left == -1) {
-        continue;
-      }
-      discounted_sum_update(ya, y.size(0), gamma, j, j_left);
-    }
-  }));
-  return y;
-}
-
-torch::Tensor discounted_cumsum_right_cpu(torch::Tensor x, double gamma) {
-  TORCH_CHECK(x.device().is_cpu(), "Input must be a CPU tensor");
-  TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
-  TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
-  auto y = x.clone();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_right_cpu_loop", ([&] {
-    auto ya = y.accessor<scalar_t, 2>();
-    for (int j=y.size(1)-1; j>=0; j--) {
-      int j_right = j+1;
-      if (j_right == y.size(1)) {
-        continue;
-      }
-      discounted_sum_update(ya, y.size(0), gamma, j, j_right);
-    }
-  }));
-  return y;
-}
-
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("integrated_conv_cpu", &integrated_conv_cpu, "Integrated convolution forward function (CPU)");
-  m.def("integrated_conv_backward_cpu", &integrated_conv_forward_cpu, "Integrated convolution backward function (CPU)");
+  m.def("integrated_conv_backward_cpu", &integrated_conv_backward_cpu, "Integrated convolution backward function (CPU)");
 }
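
The pybind fix above is needed because m.def takes a function pointer, and
integrated_conv_forward_cpu was never declared anywhere, so the reference
failed to compile; the forward function is simply named integrated_conv_cpu.
A minimal sketch of the registration pattern for a TORCH_EXTENSION_NAME module
(one-argument signatures assumed for brevity; the real functions take more
parameters):

    #include <torch/extension.h>

    // Assumed, simplified signatures -- for illustration only.
    torch::Tensor integrated_conv_cpu(torch::Tensor input);
    std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor grad_output);

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      m.def("integrated_conv_cpu", &integrated_conv_cpu,
            "Integrated convolution forward function (CPU)");
      m.def("integrated_conv_backward_cpu", &integrated_conv_backward_cpu,
            "Integrated convolution backward function (CPU)");
    }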
@@ -17,5 +17,5 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("integrated_conv_cuda", &integrated_conv_cuda, "Integrated convolution forward function (CUDA)");
-  m.def("integrated_conv_backward_cuda", &integrated_conv_forward_cuda, "Integrated convolution backward function (CUDA)");
+  m.def("integrated_conv_backward_cuda", &integrated_conv_backward_cuda, "Integrated convolution backward function (CUDA)");
 }
@@ -241,9 +241,9 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
   TORCH_CHECK(pos_add.device() == input.device() &&
               pos_mul.device() == pos_add.device(),
               "Input devices mismatch");
-  dtype scalar_t = input.dtype();
-  TORCH_CHECK(pos_add.dtype() == scalar_t &&
-              pos_mul.dtype() == scalar_t,
+  auto scalar_t = input.scalar_type();
+  TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
+              pos_mul.scalar_type() == scalar_t,
               "Input dtypes mismatch");
   torch::Tensor output = torch::empty({N, C, H, W},