Unverified commit b4a286a1, authored by Xiaodong Wang and committed by GitHub
Browse files

[AMD] hipify torchaudio

Differential Revision: D64184710

Pull Request resolved: https://github.com/pytorch/audio/pull/3840
parent 3f056993
......@@ -39,7 +39,11 @@ __global__ void ReduceMax2D(
CAST_DTYPE shf;
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
#ifndef USE_ROCM
shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
#else
shf = __shfl_down(val, stride);
#endif
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
if (shf > val) {
val = shf;
......@@ -81,7 +85,11 @@ __global__ void ReduceLogSumExpGivenMax2D(
CAST_DTYPE shf;
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
#ifndef USE_ROCM
shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
#else
shf = __shfl_down(val, stride);
#endif
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
val = val + shf;
}
......
......@@ -126,7 +126,11 @@ __device__ void ComputeAlphas(
#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
#ifndef USE_ROCM
val = __shfl_up_sync(0xffffffff, skip_prob, i);
#else
val = __shfl_up(skip_prob, i);
#endif
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
......@@ -150,7 +154,11 @@ __device__ void ComputeAlphas(
CAST_DTYPE out = val;
for (int i = 1; i < warpSize; ++i) {
#ifndef USE_ROCM
val = __shfl_up_sync(0xffffffff, val, 1);
#else
val = __shfl_up(val, 1);
#endif
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
......@@ -225,7 +233,11 @@ __device__ void ComputeBetasCosts(
#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
#ifndef USE_ROCM
val = __shfl_up_sync(0xffffffff, skip_prob, i);
#else
val = __shfl_up(skip_prob, i);
#endif
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
......@@ -248,7 +260,11 @@ __device__ void ComputeBetasCosts(
CAST_DTYPE out = val;
for (int i = 1; i < warpSize; ++i) {
#ifndef USE_ROCM
val = __shfl_up_sync(0xffffffff, val, 1);
#else
val = __shfl_up(val, 1);
#endif
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment