Commit 92814db2 authored by Daniel Povey

Fix compilation issues

parent 5775d20e
@@ -514,7 +514,7 @@ void integrated_conv_kernel_backward(
       // This group of (threads_per_kernel_pos) threads is responsible
       // for position (kh, kw) in the kernel; we iterate over the patch.
       scalar_t pos_add_val = pos_add_buf[pos_in_kernel],
-          pos_mul_val = = pos_mul_buf[pos_in_kernel];
+          pos_mul_val = pos_mul_buf[pos_in_kernel];
       for (int pos_in_patch = threadIdx.x % threads_per_kernel_pos;
            pos_in_patch < patch_size; pos_in_patch += threads_per_kernel_pos) {
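
The first hunk deletes a stray duplicated `=` in a comma-declarator initialization; the extra token makes the second initializer ill-formed, which is one of the compilation issues this commit fixes. A minimal standalone sketch of the corrected pattern, with hypothetical buffer contents standing in for the kernel's shared-memory buffers:

```cuda
#include <cstdio>

int main() {
  // Hypothetical data; in the kernel these are shared-memory buffers.
  float pos_add_buf[1] = {0.5f}, pos_mul_buf[1] = {2.0f};
  int pos_in_kernel = 0;
  // One declaration statement initializing two locals of the same type;
  // the duplicated '=' removed by this commit broke exactly this form.
  float pos_add_val = pos_add_buf[pos_in_kernel],
        pos_mul_val = pos_mul_buf[pos_in_kernel];
  std::printf("%f %f\n", pos_add_val, pos_mul_val);
  return 0;
}
```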
@@ -817,8 +817,10 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
   dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
   // blockDim is scalar, just threads_per_block.
-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_kernel", ([&] {
-    integrated_conv_kernel<scalar_t><<<gridDim, threads_per_block, sizeof(scalar_t) * buffer_numel, at::cuda::getCurrentCUDAStream()>>>(
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_kernel_backward", ([&] {
+    integrated_conv_kernel_backward<scalar_t><<<gridDim, threads_per_block,
+                                                sizeof(scalar_t) * buffer_numel,
+                                                at::cuda::getCurrentCUDAStream()>>>(
       input.packed_accessor32<scalar_t, 4>(),
       pos_add.packed_accessor32<scalar_t, 3>(),
       pos_mul.packed_accessor32<scalar_t, 3>(),
......
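
The second hunk fixes a copy-paste slip in the backward wrapper: the dispatch used the name "integrated_conv_kernel" and launched the forward kernel `integrated_conv_kernel` instead of `integrated_conv_kernel_backward`, which presumably failed to compile against the backward argument list. A minimal sketch of the dispatch-and-launch pattern used here, with a hypothetical kernel and launcher (`demo_kernel`, `demo_launch`) standing in for the real ones, which take packed accessors and a shared-memory buffer:

```cuda
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

template <typename scalar_t>
__global__ void demo_kernel(const scalar_t *in, scalar_t *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * scalar_t(2);
}

void demo_launch(torch::Tensor input, torch::Tensor output) {
  int n = input.numel();
  int threads = 256, blocks = (n + threads - 1) / threads;
  // AT_DISPATCH_FLOATING_TYPES instantiates the lambda with scalar_t bound
  // to the tensor's dtype (float or double). The string only names the op
  // in error messages, so the stale name in the old code was misleading
  // rather than fatal; launching the wrong kernel template was the bug.
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "demo_kernel", ([&] {
    demo_kernel<scalar_t><<<blocks, threads, 0,
                            at::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), n);
  }));
}
```

The launch in the patch passes `sizeof(scalar_t) * buffer_numel` as the third launch parameter to reserve dynamic shared memory for the kernel's buffers; the sketch above passes 0 because `demo_kernel` uses none.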