Unverified Commit 0bb82bb6 authored by Jinze (Richard) Xue's avatar Jinze (Richard) Xue Committed by GitHub
Browse files

[bugfix] fix deadlock on ampere of angular backward kernel (#589)

parent b314360c
...@@ -522,6 +522,8 @@ __global__ void cuAngularAEVs_backward_or_doublebackward( ...@@ -522,6 +522,8 @@ __global__ void cuAngularAEVs_backward_or_doublebackward(
DataT fc_ijk = fc_ij * fc_ik; DataT fc_ijk = fc_ij * fc_ik;
IndexT subaev_offset = angular_sublength * csubaev_offsets(type_j, type_k, num_species); IndexT subaev_offset = angular_sublength * csubaev_offsets(type_j, type_k, num_species);
float3 grad_vij = make_float3(0.f, 0.f, 0.f);
float3 grad_vik = make_float3(0.f, 0.f, 0.f);
for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) { for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) {
DataT ShfZ = ShfZ_t[itheta]; DataT ShfZ = ShfZ_t[itheta];
...@@ -583,28 +585,36 @@ __global__ void cuAngularAEVs_backward_or_doublebackward( ...@@ -583,28 +585,36 @@ __global__ void cuAngularAEVs_backward_or_doublebackward(
grad_vik_y *= grad_output_item; grad_vik_y *= grad_output_item;
grad_vik_z *= grad_output_item; grad_vik_z *= grad_output_item;
sdix_grad += (-grad_vij_x - grad_vik_x); grad_vij.x += grad_vij_x;
sdiy_grad += (-grad_vij_y - grad_vik_y); grad_vij.y += grad_vij_y;
sdiz_grad += (-grad_vij_z - grad_vik_z); grad_vij.z += grad_vij_z;
grad_vik.x += grad_vik_x;
grad_vik.y += grad_vik_y;
grad_vik.z += grad_vik_z;
}
}
}
if (!is_double_backward) {
sdix_grad += (-grad_vij.x - grad_vik.x);
sdiy_grad += (-grad_vij.y - grad_vik.y);
sdiz_grad += (-grad_vij.z - grad_vik.z);
for (int offset = 16; offset > 0; offset /= 2) { for (int offset = 16; offset > 0; offset /= 2) {
grad_vij_x += __shfl_down_sync(0xFFFFFFFF, grad_vij_x, offset); grad_vij.x += __shfl_down_sync(0xFFFFFFFF, grad_vij.x, offset);
grad_vij_y += __shfl_down_sync(0xFFFFFFFF, grad_vij_y, offset); grad_vij.y += __shfl_down_sync(0xFFFFFFFF, grad_vij.y, offset);
grad_vij_z += __shfl_down_sync(0xFFFFFFFF, grad_vij_z, offset); grad_vij.z += __shfl_down_sync(0xFFFFFFFF, grad_vij.z, offset);
grad_vik_x += __shfl_down_sync(0xFFFFFFFF, grad_vik_x, offset); grad_vik.x += __shfl_down_sync(0xFFFFFFFF, grad_vik.x, offset);
grad_vik_y += __shfl_down_sync(0xFFFFFFFF, grad_vik_y, offset); grad_vik.y += __shfl_down_sync(0xFFFFFFFF, grad_vik.y, offset);
grad_vik_z += __shfl_down_sync(0xFFFFFFFF, grad_vik_z, offset); grad_vik.z += __shfl_down_sync(0xFFFFFFFF, grad_vik.z, offset);
} }
if (laneIdx == 0) { if (laneIdx == 0) {
sdjx_grad[jj] += grad_vij_x; sdjx_grad[jj] += grad_vij.x;
sdjy_grad[jj] += grad_vij_y; sdjy_grad[jj] += grad_vij.y;
sdjz_grad[jj] += grad_vij_z; sdjz_grad[jj] += grad_vij.z;
sdjx_grad[kk] += grad_vik_x; sdjx_grad[kk] += grad_vik.x;
sdjy_grad[kk] += grad_vik_y; sdjy_grad[kk] += grad_vik.y;
sdjz_grad[kk] += grad_vik_z; sdjz_grad[kk] += grad_vik.z;
}
}
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment