Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
0bb82bb6
Unverified
Commit
0bb82bb6
authored
Apr 15, 2021
by
Jinze (Richard) Xue
Committed by
GitHub
Apr 15, 2021
Browse files
[bugfix] fix deadlock on ampere of angular backward kernel (#589)
parent
b314360c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
21 deletions
+31
-21
torchani/cuaev/aev.cu
torchani/cuaev/aev.cu
+31
-21
No files found.
torchani/cuaev/aev.cu
View file @
0bb82bb6
...
@@ -522,6 +522,8 @@ __global__ void cuAngularAEVs_backward_or_doublebackward(
...
@@ -522,6 +522,8 @@ __global__ void cuAngularAEVs_backward_or_doublebackward(
DataT
fc_ijk
=
fc_ij
*
fc_ik
;
DataT
fc_ijk
=
fc_ij
*
fc_ik
;
IndexT
subaev_offset
=
angular_sublength
*
csubaev_offsets
(
type_j
,
type_k
,
num_species
);
IndexT
subaev_offset
=
angular_sublength
*
csubaev_offsets
(
type_j
,
type_k
,
num_species
);
float3
grad_vij
=
make_float3
(
0.
f
,
0.
f
,
0.
f
);
float3
grad_vik
=
make_float3
(
0.
f
,
0.
f
,
0.
f
);
for
(
int
itheta
=
tile
.
x
;
itheta
<
nShfZ
;
itheta
+=
TILEX
)
{
for
(
int
itheta
=
tile
.
x
;
itheta
<
nShfZ
;
itheta
+=
TILEX
)
{
DataT
ShfZ
=
ShfZ_t
[
itheta
];
DataT
ShfZ
=
ShfZ_t
[
itheta
];
...
@@ -583,28 +585,36 @@ __global__ void cuAngularAEVs_backward_or_doublebackward(
...
@@ -583,28 +585,36 @@ __global__ void cuAngularAEVs_backward_or_doublebackward(
grad_vik_y
*=
grad_output_item
;
grad_vik_y
*=
grad_output_item
;
grad_vik_z
*=
grad_output_item
;
grad_vik_z
*=
grad_output_item
;
sdix_grad
+=
(
-
grad_vij_x
-
grad_vik_x
);
grad_vij
.
x
+=
grad_vij_x
;
sdiy_grad
+=
(
-
grad_vij_y
-
grad_vik_y
);
grad_vij
.
y
+=
grad_vij_y
;
sdiz_grad
+=
(
-
grad_vij_z
-
grad_vik_z
);
grad_vij
.
z
+=
grad_vij_z
;
grad_vik
.
x
+=
grad_vik_x
;
grad_vik
.
y
+=
grad_vik_y
;
grad_vik
.
z
+=
grad_vik_z
;
}
}
}
if
(
!
is_double_backward
)
{
sdix_grad
+=
(
-
grad_vij
.
x
-
grad_vik
.
x
);
sdiy_grad
+=
(
-
grad_vij
.
y
-
grad_vik
.
y
);
sdiz_grad
+=
(
-
grad_vij
.
z
-
grad_vik
.
z
);
for
(
int
offset
=
16
;
offset
>
0
;
offset
/=
2
)
{
for
(
int
offset
=
16
;
offset
>
0
;
offset
/=
2
)
{
grad_vij
_
x
+=
__shfl_down_sync
(
0xFFFFFFFF
,
grad_vij
_
x
,
offset
);
grad_vij
.
x
+=
__shfl_down_sync
(
0xFFFFFFFF
,
grad_vij
.
x
,
offset
);
grad_vij
_
y
+=
__shfl_down_sync
(
0xFFFFFFFF
,
grad_vij
_
y
,
offset
);
grad_vij
.
y
+=
__shfl_down_sync
(
0xFFFFFFFF
,
grad_vij
.
y
,
offset
);
grad_vij
_
z
+=
__shfl_down_sync
(
0xFFFFFFFF
,
grad_vij
_
z
,
offset
);
grad_vij
.
z
+=
__shfl_down_sync
(
0xFFFFFFFF
,
grad_vij
.
z
,
offset
);
grad_vik
_
x
+=
__shfl_down_sync
(
0xFFFFFFFF
,
grad_vik
_
x
,
offset
);
grad_vik
.
x
+=
__shfl_down_sync
(
0xFFFFFFFF
,
grad_vik
.
x
,
offset
);
grad_vik
_
y
+=
__shfl_down_sync
(
0xFFFFFFFF
,
grad_vik
_
y
,
offset
);
grad_vik
.
y
+=
__shfl_down_sync
(
0xFFFFFFFF
,
grad_vik
.
y
,
offset
);
grad_vik
_
z
+=
__shfl_down_sync
(
0xFFFFFFFF
,
grad_vik
_
z
,
offset
);
grad_vik
.
z
+=
__shfl_down_sync
(
0xFFFFFFFF
,
grad_vik
.
z
,
offset
);
}
}
if
(
laneIdx
==
0
)
{
if
(
laneIdx
==
0
)
{
sdjx_grad
[
jj
]
+=
grad_vij
_
x
;
sdjx_grad
[
jj
]
+=
grad_vij
.
x
;
sdjy_grad
[
jj
]
+=
grad_vij
_
y
;
sdjy_grad
[
jj
]
+=
grad_vij
.
y
;
sdjz_grad
[
jj
]
+=
grad_vij
_
z
;
sdjz_grad
[
jj
]
+=
grad_vij
.
z
;
sdjx_grad
[
kk
]
+=
grad_vik_x
;
sdjx_grad
[
kk
]
+=
grad_vik
.
x
;
sdjy_grad
[
kk
]
+=
grad_vik_y
;
sdjy_grad
[
kk
]
+=
grad_vik
.
y
;
sdjz_grad
[
kk
]
+=
grad_vik_z
;
sdjz_grad
[
kk
]
+=
grad_vik
.
z
;
}
}
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment