Commit 9995abcd authored by flyingdown's avatar flyingdown
Browse files

fix up rnnt for dcu

parent b6c4b068
......@@ -115,7 +115,7 @@ __device__ void ComputeAlphas(
#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
val = __shfl_up_sync(0xffffffff, skip_prob, i);
val = __shfl_up(skip_prob, i);
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
......@@ -139,7 +139,7 @@ __device__ void ComputeAlphas(
CAST_DTYPE out = val;
for (int i = 1; i < warpSize; ++i) {
val = __shfl_up_sync(0xffffffff, val, 1);
val = __shfl_up(val, 1);
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
......@@ -214,7 +214,7 @@ __device__ void ComputeBetasCosts(
#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
val = __shfl_up_sync(0xffffffff, skip_prob, i);
val = __shfl_up(skip_prob, i);
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
......@@ -237,7 +237,7 @@ __device__ void ComputeBetasCosts(
CAST_DTYPE out = val;
for (int i = 1; i < warpSize; ++i) {
val = __shfl_up_sync(0xffffffff, val, 1);
val = __shfl_up(val, 1);
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment