Unverified Commit c9627e86 authored by liuhw, committed by GitHub

[Fix] Fix arf op's write conflict when num_orientations is not 1 (#2824)

parent 50d1fffb
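The conflict comes from the output-offset arithmetic. Before this patch, the offset inside each (output plane i, rotation k, input plane j) block used only idx, the rotated spatial position read from indices_data, so entries sitting at the same spatial cell but in different orientation channels all mapped to the same address and raced to write it whenever num_orientations > 1. The snippet below is a minimal, standalone illustration of that arithmetic, not code from this commit; it assumes, as the patched kernels imply, that the lookup values in indices cover only the kH * kW spatial positions, and the offset_* helper names plus the toy sizes are invented for the example.

#include <cassert>
#include <cstdio>

// Offset inside one (i, k, j) block of the output (nEntry elements).
// Old formula: the orientation channel of entry l is ignored.
int offset_old(int /*l*/, int idx, int /*kH*/, int /*kW*/) { return idx; }

// Patched formula: fmIndex pins entry l to its own kH*kW slice.
int offset_new(int l, int idx, int kH, int kW) {
  int fmIndex = (l / (kH * kW)) * kH * kW;
  return idx + fmIndex;
}

int main() {
  const int kH = 3, kW = 3;            // nEntry = num_orientations * kH * kW
  const int l0 = 4, l1 = 4 + kH * kW;  // same spatial cell, adjacent orientation channels
  const int idx = 4;                   // both rotate onto spatial position 4

  // Before: two different weight entries write the same output address.
  assert(offset_old(l0, idx, kH, kW) == offset_old(l1, idx, kH, kW));
  // After: they land in disjoint orientation slices of the block.
  assert(offset_new(l0, idx, kH, kW) != offset_new(l1, idx, kH, kW));

  std::printf("old: %d vs %d | new: %d vs %d\n",
              offset_old(l0, idx, kH, kW), offset_old(l1, idx, kH, kW),
              offset_new(l0, idx, kH, kW), offset_new(l1, idx, kH, kW));
  return 0;
}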
@@ -15,18 +15,19 @@ __global__ void active_rotated_filter_forward_cuda_kernel(
     const int nthreads, const scalar_t* weight_data, const int* indices_data,
     const int num_input_planes, const int num_output_planes,
     const int num_orientations, const int num_rotations, const int nEntry,
-    scalar_t* output_data) {
+    const int kH, const int kW, scalar_t* output_data) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     int l = index % nEntry;
     int j = (index / nEntry) % num_input_planes;
     int i = index / nEntry / num_input_planes;
     int k;
+    int fmIndex = (l / (kH * kW)) * kH * kW;
     scalar_t val = *(weight_data + index);
     for (k = 0; k < num_rotations; k++) {
       int idx = (int)(*(indices_data + l * num_rotations + k)) - 1;
-      scalar_t* target = output_data +
-                         i * (num_rotations * num_input_planes * nEntry) +
-                         k * (num_input_planes * nEntry) + j * (nEntry) + idx;
+      scalar_t* target =
+          output_data + i * (num_rotations * num_input_planes * nEntry) +
+          k * (num_input_planes * nEntry) + j * (nEntry) + idx + fmIndex;
       *target = val;
     }
   }
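For readers following the forward kernel above: each CUDA thread owns one element of the canonical weight (so nthreads is num_output_planes * num_input_planes * nEntry) and scatters it to num_rotations rotated positions. The host-side sketch below only mirrors the index decomposition used in that kernel; the ArfCoord struct, the decompose name, and the toy sizes are illustrative, not part of mmcv.

#include <cstdio>

// Mirrors the decomposition in active_rotated_filter_forward_cuda_kernel:
// flat thread index -> (output plane i, input plane j, entry l), where l
// runs over the num_orientations * kH * kW entries of one filter stack and
// fmIndex is the start of l's orientation slice (added by this fix).
struct ArfCoord {
  int i, j, l, fmIndex;
};

ArfCoord decompose(int index, int num_input_planes, int nEntry, int kH, int kW) {
  ArfCoord c;
  c.l = index % nEntry;
  c.j = (index / nEntry) % num_input_planes;
  c.i = index / nEntry / num_input_planes;
  c.fmIndex = (c.l / (kH * kW)) * kH * kW;
  return c;
}

int main() {
  // Toy sizes: 3 input planes, 4 orientations, 3x3 kernel -> nEntry = 36.
  const int num_input_planes = 3, num_orientations = 4, kH = 3, kW = 3;
  const int nEntry = num_orientations * kH * kW;
  ArfCoord c = decompose(/*index=*/100, num_input_planes, nEntry, kH, kW);
  std::printf("i=%d j=%d l=%d fmIndex=%d\n", c.i, c.j, c.l, c.fmIndex);  // i=0 j=2 l=28 fmIndex=27
  return 0;
}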
@@ -37,12 +38,14 @@ __global__ void active_rotated_filter_backward_cuda_kernel(
     const int nthreads, const scalar_t* gradWeight_data,
     const int* indices_data, const int num_input_planes,
     const int num_output_planes, const int num_orientations,
-    const int num_rotations, const int nEntry, scalar_t* weight_data) {
+    const int num_rotations, const int nEntry, const int kH, const int kW,
+    scalar_t* weight_data) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     int l = index % nEntry;
     int j = (index / nEntry) % num_input_planes;
     int i = index / nEntry / num_input_planes;
     int k;
+    int fmIndex = (l / (kH * kW)) * kH * kW;
     scalar_t* val = weight_data + index;
     *val = 0;
     scalar_t tmp = 0;
@@ -50,7 +53,7 @@ __global__ void active_rotated_filter_backward_cuda_kernel(
       int idx = (int)(*(indices_data + l * num_rotations + k)) - 1;
       scalar_t target =
           *(gradWeight_data + i * (num_rotations * num_input_planes * nEntry) +
-            k * (num_input_planes * nEntry) + j * (nEntry) + idx);
+            k * (num_input_planes * nEntry) + j * (nEntry) + idx + fmIndex);
       tmp = tmp + target;
     }
     *val = tmp;
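The backward kernel is the gather counterpart of the forward scatter: each thread sums, into tmp, the gradients of the num_rotations rotated copies of its weight element, so no atomics are needed as long as the read offsets match the forward writes exactly. With the + fmIndex term, entries from different orientation channels can never address the same slot of an (i, k, j) block, which is the property the small standalone check below exercises; the rotated_offset helper and toy sizes are made up, and it assumes the lookup values stay within the kH * kW spatial range.

#include <cassert>

// Offset of a rotated copy inside one (i, k, j) block of the output,
// matching the "idx + fmIndex" expression in both patched kernels.
int rotated_offset(int l, int idx, int kH, int kW) {
  return idx + (l / (kH * kW)) * kH * kW;
}

int main() {
  const int kH = 3, kW = 3, num_orientations = 4;
  const int spatial = kH * kW, nEntry = num_orientations * spatial;

  // Entries from different orientation channels never collide, whatever
  // spatial position the lookup table maps them to, so the forward pass is
  // conflict-free and the backward gather reads back exactly what was written.
  for (int l0 = 0; l0 < nEntry; ++l0)
    for (int l1 = 0; l1 < nEntry; ++l1)
      if (l0 / spatial != l1 / spatial)
        for (int a = 0; a < spatial; ++a)
          for (int b = 0; b < spatial; ++b)
            assert(rotated_offset(l0, a, kH, kW) != rotated_offset(l1, b, kH, kW));
  return 0;
}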
@@ -19,11 +19,12 @@ void active_rotated_filter_forward_cpu_kernel(
       for (l = 0; l < nEntry; l++) {
         int weightIndex = i * num_input_planes * nEntry + j * nEntry + l;
         T val = *(weightData + weightIndex);
+        int fmIndex = (l / (kH * kW)) * kH * kW;
         for (k = 0; k < num_rotations; k++) {
           int index = (int)(*(indicesData + l * num_rotations + k)) - 1;
-          T* target = outputData +
-                      i * (num_rotations * num_input_planes * nEntry) +
-                      k * (num_input_planes * nEntry) + j * (nEntry) + index;
+          T* target =
+              outputData + i * (num_rotations * num_input_planes * nEntry) +
+              k * (num_input_planes * nEntry) + j * (nEntry) + index + fmIndex;
           *target = val;
         }
       }
@@ -48,11 +49,12 @@ void active_rotated_filter_backward_cpu_kernel(
         int gradInputIndex = i * num_input_planes * nEntry + j * nEntry + l;
         T* val = gradInputData + gradInputIndex;
         *val = 0;
+        int fmIndex = (l / (kH * kW)) * kH * kW;
         for (k = 0; k < num_rotations; k++) {
           int index = (int)(*(indicesData + l * num_rotations + k)) - 1;
           const T* target =
               gradOutputData + i * (num_rotations * num_input_planes * nEntry) +
-              k * (num_input_planes * nEntry) + j * (nEntry) + index;
+              k * (num_input_planes * nEntry) + j * (nEntry) + index + fmIndex;
           *val = *val + *target;
         }
       }
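The CPU changes are the same fix in loop form. As an end-to-end sanity check, the standalone sketch below replays the patched forward loops on a toy configuration and counts how many times each output element is written; with + fmIndex every slot is written exactly once, and dropping the term reproduces the original conflict. The sizes and the identity spatial lookup table are invented for the example and are not the real indices produced by mmcv.

#include <cassert>
#include <vector>

int main() {
  // Toy configuration: 2 output planes, 2 input planes, 4 orientations,
  // 3x3 kernel, 4 rotations.
  const int num_output_planes = 2, num_input_planes = 2;
  const int num_orientations = 4, num_rotations = 4, kH = 3, kW = 3;
  const int nEntry = num_orientations * kH * kW;

  // indices[l][k]: 1-based rotated spatial position, here simply the
  // identity permutation within each kH*kW slice.
  std::vector<int> indices(nEntry * num_rotations);
  for (int l = 0; l < nEntry; ++l)
    for (int k = 0; k < num_rotations; ++k)
      indices[l * num_rotations + k] = l % (kH * kW) + 1;

  // One counter per element of the rotated output
  // [num_output_planes * num_rotations, num_input_planes * nEntry].
  std::vector<int> writes(
      num_output_planes * num_rotations * num_input_planes * nEntry, 0);

  // Same structure as the patched active_rotated_filter_forward_cpu_kernel.
  for (int i = 0; i < num_output_planes; ++i)
    for (int j = 0; j < num_input_planes; ++j)
      for (int l = 0; l < nEntry; ++l) {
        int fmIndex = (l / (kH * kW)) * kH * kW;
        for (int k = 0; k < num_rotations; ++k) {
          int index = indices[l * num_rotations + k] - 1;
          int target = i * (num_rotations * num_input_planes * nEntry) +
                       k * (num_input_planes * nEntry) + j * nEntry + index +
                       fmIndex;  // remove "+ fmIndex" to reproduce the conflict
          ++writes[target];
        }
      }

  for (int w : writes) assert(w == 1);  // no slot written twice, none missed
  return 0;
}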
@@ -24,7 +24,7 @@ void ActiveRotatedFilterForwardCUDAKernelLauncher(const Tensor input,
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, input.data_ptr<scalar_t>(),
                indices.data_ptr<int>(), num_input_planes, num_output_planes,
-                num_orientations, num_rotations, nEntry,
+                num_orientations, num_rotations, nEntry, kH, kW,
                output.data_ptr<scalar_t>());
       });
   AT_CUDA_CHECK(cudaGetLastError());
@@ -51,7 +51,7 @@ void ActiveRotatedFilterBackwardCUDAKernelLauncher(const Tensor grad_out,
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, grad_out.data_ptr<scalar_t>(),
                indices.data_ptr<int>(), num_input_planes, num_output_planes,
-                num_orientations, num_rotations, nEntry,
+                num_orientations, num_rotations, nEntry, kH, kW,
                grad_in.data_ptr<scalar_t>());
       });
   AT_CUDA_CHECK(cudaGetLastError());
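The launcher hunks only show the two extra arguments, kH and kW, being forwarded to the kernels. They would presumably be read off the weight tensor, whose layout for this op is documented as (num_output_planes, num_input_planes, num_orientations, kH, kW), with indices shaped (num_orientations, kH, kW, num_rotations). The sketch below is a hedged, host-side illustration of that shape bookkeeping with hard-coded toy dimensions, not the actual mmcv launcher code.

#include <array>
#include <cstdio>

int main() {
  // Assumed tensor layouts for this op (toy sizes):
  //   weight : [num_output_planes, num_input_planes, num_orientations, kH, kW]
  //   indices: [num_orientations, kH, kW, num_rotations]
  const std::array<int, 5> weight_sizes = {256, 256, 8, 3, 3};
  const std::array<int, 4> indices_sizes = {8, 3, 3, 8};

  const int num_output_planes = weight_sizes[0];
  const int num_input_planes = weight_sizes[1];
  const int num_orientations = weight_sizes[2];
  const int kH = weight_sizes[3];
  const int kW = weight_sizes[4];
  const int num_rotations = indices_sizes[3];
  const int nEntry = num_orientations * kH * kW;
  // One thread per canonical weight element, per the kernel decomposition.
  const int output_size = num_output_planes * num_input_planes * nEntry;

  // These are the values the patched launchers pass through as
  // "..., num_orientations, num_rotations, nEntry, kH, kW, ...".
  std::printf("threads=%d nEntry=%d kH=%d kW=%d rotations=%d\n", output_size,
              nEntry, kH, kW, num_rotations);
  return 0;
}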