Commit 9995abcd authored by flyingdown
Browse files

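Fix up RNN-T for DCU: replace `__shfl_up_sync(0xffffffff, ...)` with `__shfl_up(...)` in ComputeAlphas and ComputeBetasCosts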

parent b6c4b068
...@@ -115,7 +115,7 @@ __device__ void ComputeAlphas( ...@@ -115,7 +115,7 @@ __device__ void ComputeAlphas(
#pragma unroll #pragma unroll
for (int i = 1; i < warpSize; i <<= 1) { for (int i = 1; i < warpSize; i <<= 1) {
val = __shfl_up_sync(0xffffffff, skip_prob, i); val = __shfl_up(skip_prob, i);
if (i <= threadIdx.x) { if (i <= threadIdx.x) {
skip_prob = skip_prob + val; skip_prob = skip_prob + val;
} }
...@@ -139,7 +139,7 @@ __device__ void ComputeAlphas( ...@@ -139,7 +139,7 @@ __device__ void ComputeAlphas(
CAST_DTYPE out = val; CAST_DTYPE out = val;
for (int i = 1; i < warpSize; ++i) { for (int i = 1; i < warpSize; ++i) {
val = __shfl_up_sync(0xffffffff, val, 1); val = __shfl_up(val, 1);
if (i == threadIdx.x) { if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit); val = math::lse(val + skip_prob, emit);
out = val; out = val;
...@@ -214,7 +214,7 @@ __device__ void ComputeBetasCosts( ...@@ -214,7 +214,7 @@ __device__ void ComputeBetasCosts(
#pragma unroll #pragma unroll
for (int i = 1; i < warpSize; i <<= 1) { for (int i = 1; i < warpSize; i <<= 1) {
val = __shfl_up_sync(0xffffffff, skip_prob, i); val = __shfl_up(skip_prob, i);
if (i <= threadIdx.x) { if (i <= threadIdx.x) {
skip_prob = skip_prob + val; skip_prob = skip_prob + val;
} }
...@@ -237,7 +237,7 @@ __device__ void ComputeBetasCosts( ...@@ -237,7 +237,7 @@ __device__ void ComputeBetasCosts(
CAST_DTYPE out = val; CAST_DTYPE out = val;
for (int i = 1; i < warpSize; ++i) { for (int i = 1; i < warpSize; ++i) {
val = __shfl_up_sync(0xffffffff, val, 1); val = __shfl_up(val, 1);
if (i == threadIdx.x) { if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit); val = math::lse(val + skip_prob, emit);
out = val; out = val;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment