Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
9995abcd
Commit
9995abcd
authored
May 08, 2023
by
flyingdown
Browse files
fix up rnnt for dcu
parent
b6c4b068
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
torchaudio/csrc/rnnt/dcu/gpu_kernels.cuh
torchaudio/csrc/rnnt/dcu/gpu_kernels.cuh
+4
-4
No files found.
torchaudio/csrc/rnnt/dcu/gpu_kernels.cuh
View file @
9995abcd
...
...
@@ -115,7 +115,7 @@ __device__ void ComputeAlphas(
#pragma unroll
for
(
int
i
=
1
;
i
<
warpSize
;
i
<<=
1
)
{
val
=
__shfl_up
_sync
(
0xffffffff
,
skip_prob
,
i
);
val
=
__shfl_up
(
skip_prob
,
i
);
if
(
i
<=
threadIdx
.
x
)
{
skip_prob
=
skip_prob
+
val
;
}
...
...
@@ -139,7 +139,7 @@ __device__ void ComputeAlphas(
CAST_DTYPE
out
=
val
;
for
(
int
i
=
1
;
i
<
warpSize
;
++
i
)
{
val
=
__shfl_up
_sync
(
0xffffffff
,
val
,
1
);
val
=
__shfl_up
(
val
,
1
);
if
(
i
==
threadIdx
.
x
)
{
val
=
math
::
lse
(
val
+
skip_prob
,
emit
);
out
=
val
;
...
...
@@ -214,7 +214,7 @@ __device__ void ComputeBetasCosts(
#pragma unroll
for
(
int
i
=
1
;
i
<
warpSize
;
i
<<=
1
)
{
val
=
__shfl_up
_sync
(
0xffffffff
,
skip_prob
,
i
);
val
=
__shfl_up
(
skip_prob
,
i
);
if
(
i
<=
threadIdx
.
x
)
{
skip_prob
=
skip_prob
+
val
;
}
...
...
@@ -237,7 +237,7 @@ __device__ void ComputeBetasCosts(
CAST_DTYPE
out
=
val
;
for
(
int
i
=
1
;
i
<
warpSize
;
++
i
)
{
val
=
__shfl_up
_sync
(
0xffffffff
,
val
,
1
);
val
=
__shfl_up
(
val
,
1
);
if
(
i
==
threadIdx
.
x
)
{
val
=
math
::
lse
(
val
+
skip_prob
,
emit
);
out
=
val
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment