"src/vscode:/vscode.git/clone" did not exist on "4a60b45d4c4c80fa934d33e51edfcd29f9795470"
Unverified commit 79047bf6, authored by Xiaodong Wang, committed by GitHub

[AMD] Hipify torchaudio_decoder

Differential Revision: D64298970

Pull Request resolved: https://github.com/pytorch/audio/pull/3843
parent b4a286a1
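For background: PyTorch's hipify pass (torch.utils.hipify) runs over these sources during the ROCm build and textually rewrites CUDA names to their HIP equivalents; the hunks below are the manual fixes needed where that translation alone does not produce valid HIP code. A rough, simplified sketch of the automatic rewriting (the real mapping table is far larger):

```cuda
#include <cuda_runtime.h>  // hipify maps this header to <hip/hip_runtime.h>

// hipify rewrites the cuda* names below to hip* (cudaError_t -> hipError_t,
// cudaMemcpyAsync -> hipMemcpyAsync, cudaMemcpyDeviceToDevice -> hipMemcpyDeviceToDevice).
cudaError_t copy_async(void* dst, const void* src, size_t nbytes, cudaStream_t stream) {
  return cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, stream);
}
```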
@@ -26,7 +26,6 @@
#ifndef __ctc_prefix_decoder_h_
#define __ctc_prefix_decoder_h_
#include <cuda_runtime.h>
#include <cstdint>
#include <tuple>
#include <vector>
@@ -26,24 +26,6 @@
#ifndef __ctc_prefix_decoder_host_h_
#define __ctc_prefix_decoder_host_h_
#include <cuda_runtime.h>
-#define CUDA_CHECK(X)                                   \
-  do {                                                  \
-    auto result = X;                                    \
-    if (result != cudaSuccess) {                        \
-      const char* p_err_str = cudaGetErrorName(result); \
-      fprintf(                                          \
-          stderr,                                       \
-          "File %s Line %d %s returned %s.\n",          \
-          __FILE__,                                     \
-          __LINE__,                                     \
-          #X,                                           \
-          p_err_str);                                   \
-      abort();                                          \
-    }                                                   \
-  } while (0)
#define CHECK(X, ERROR_INFO) \
do { \
auto result = (X); \
@@ -16,9 +16,13 @@ constexpr inline __host__ __device__ bool isPo2(IntType num) {
}
inline __device__ int laneId() {
+#ifndef USE_ROCM
  int id;
  asm("mov.s32 %0, %%laneid;" : "=r"(id));
  return id;
+#else
+  return __lane_id();
+#endif
}
/**
* @brief Shuffle the data inside a warp
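This hunk is the core portability fix: `%laneid` is a PTX special register, so the inline `asm` only compiles for NVIDIA targets; under `USE_ROCM` the HIP intrinsic `__lane_id()` is used instead. Note that AMD wavefronts are 64 lanes wide, so lane IDs run 0..63 there rather than 0..31. A self-contained sketch (the demo kernel is hypothetical, not part of the patch):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

inline __device__ int laneId() {
#ifndef USE_ROCM
  int id;
  asm("mov.s32 %0, %%laneid;" : "=r"(id));  // PTX special register, NVIDIA-only
  return id;
#else
  return __lane_id();  // HIP intrinsic; lanes 0..63 on 64-wide wavefronts
#endif
}

// Hypothetical demo kernel: each thread reports its lane within the warp.
__global__ void show_lanes() {
  printf("thread %d -> lane %d\n", threadIdx.x, laneId());
}
```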
@@ -12,7 +12,8 @@ namespace cu_ctc {
* @tparam IntType data type (checked only for integers)
*/
template <typename IntType>
-constexpr __device__ IntType log2(IntType num, IntType ret = IntType(0)) {
+constexpr __host__ __device__ IntType
+log2(IntType num, IntType ret = IntType(0)) {
  return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret);
}
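Adding `__host__` lets the same constexpr recursion be evaluated in host code too (for example when computing kernel template parameters); hipcc's clang front end is stricter than nvcc about host code calling `__device__`-only functions. A minimal compile-time check:

```cuda
namespace cu_ctc {
template <typename IntType>
constexpr __host__ __device__ IntType
log2(IntType num, IntType ret = IntType(0)) {
  return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret);
}
} // namespace cu_ctc

// With __host__ added, the function is usable in host constant expressions:
static_assert(cu_ctc::log2(1024) == 10, "2^10 == 1024");
static_assert(cu_ctc::log2(1000) == 9, "result is floor(log2(num))");
```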
@@ -313,7 +313,7 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> {
  __device__ __forceinline__ void merge_buf_() {
    topk::bitonic<kMaxBufLen>(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
-    this->merge_in<kMaxBufLen>(val_buf_, idx_buf_);
+    this->template merge_in<kMaxBufLen>(val_buf_, idx_buf_);
    buf_len_ = 0;
    set_k_th_(); // contains warp sync
#pragma unroll
@@ -385,7 +385,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
    if (buf_len_ == kMaxArrLen) {
      topk::bitonic<kMaxArrLen>(!Ascending, kWarpWidth)
          .sort(val_buf_, idx_buf_);
-      this->merge_in<kMaxArrLen>(val_buf_, idx_buf_);
+      this->template merge_in<kMaxArrLen>(val_buf_, idx_buf_);
#pragma unroll
      for (int i = 0; i < kMaxArrLen; i++) {
        val_buf_[i] = kDummy;
@@ -398,7 +398,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
    if (buf_len_ != 0) {
      topk::bitonic<kMaxArrLen>(!Ascending, kWarpWidth)
          .sort(val_buf_, idx_buf_);
-      this->merge_in<kMaxArrLen>(val_buf_, idx_buf_);
+      this->template merge_in<kMaxArrLen>(val_buf_, idx_buf_);
    }
  }
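These three hunks are one and the same C++ conformance fix. The warp_sort classes derive from a base that depends on their template parameters, so `merge_in` is a dependent name: without the `template` keyword, the `<` in `this->merge_in<kMaxBufLen>(...)` parses as a less-than comparison. nvcc accepts the unqualified spelling as an extension; hipcc's clang front end rejects it. A reduced illustration (class and member names here are illustrative, not the real ones):

```cuda
template <int Capacity>
struct sort_base {
  template <int N>
  __device__ void merge_in(float*, int*) {}
};

template <int Capacity>
struct sort_derived : sort_base<Capacity> {
  __device__ void flush(float* val, int* idx) {
    // this->merge_in<4>(val, idx);        // ill-formed: '<' parses as operator<
    this->template merge_in<4>(val, idx);  // 'template' marks a member template
  }
};
```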
@@ -421,7 +421,7 @@ constexpr inline __host__ __device__ IntType ceildiv(IntType a, IntType b) {
  return (a + b - 1) / b;
}
template <typename IntType>
-constexpr inline __device__ IntType roundUp256(IntType num) {
+constexpr inline __host__ __device__ IntType roundUp256(IntType num) {
  // return (num + 255) / 256 * 256;
  constexpr int MASK = 255;
  return (num + MASK) & (~MASK);
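`ceildiv` is rounded-up integer division, and `roundUp256` relies on 256 being a power of two: adding 255 and masking off the low eight bits rounds any value up to the next multiple of 256. Making it `__host__ __device__` lets host-side allocation code compute the same padded sizes the kernels assume. A compile-time check of the identities:

```cuda
template <typename IntType>
constexpr inline __host__ __device__ IntType ceildiv(IntType a, IntType b) {
  return (a + b - 1) / b;
}

template <typename IntType>
constexpr inline __host__ __device__ IntType roundUp256(IntType num) {
  constexpr int MASK = 255;  // works only because 256 is a power of two
  return (num + MASK) & (~MASK);
}

static_assert(ceildiv(10, 3) == 4, "10/3 rounds up to 4");
static_assert(roundUp256(1) == 256 && roundUp256(256) == 256, "round up to 256");
static_assert(roundUp256(257) == 512, "same as ceildiv(num, 256) * 256");
```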
@@ -25,8 +25,8 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cuda_runtime.h>
#include "include/ctc_prefix_decoder.h"
#include "include/ctc_prefix_decoder_host.h"
#include "../include/ctc_prefix_decoder.h"
#include "../include/ctc_prefix_decoder_host.h"
#include "device_data_wrap.h"
#include "device_log_prob.cuh"
@@ -23,12 +23,13 @@
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <float.h>
#include <algorithm>
#include "../include/ctc_prefix_decoder_host.h"
#include "ctc_fast_divmod.cuh"
#include "cub/cub.cuh"
#include "device_data_wrap.h"
#include "device_log_prob.cuh"
#include "include/ctc_prefix_decoder_host.h"
#include "bitonic_topk/warpsort_topk.cuh"
@@ -630,7 +631,8 @@ int CTC_prob_first_step_V2(
      num_of_subwarp, beam));
  int smem_size =
      block_sort_smem_size + beam * sizeof(float) + beam * sizeof(int);
-  FirstMatrixFuns[fun_idx]<<<grid, threads_per_block, smem_size, stream>>>(
+  auto kernel = FirstMatrixFuns[fun_idx];
+  kernel<<<grid, threads_per_block, smem_size, stream>>>(
      (*log_prob_struct),
      step,
      pprev,
@@ -766,7 +768,8 @@ int CTC_prob_topK_V2(
  int num_of_subwarp = threads_per_block0 / std::min<int>(32, actual_capacity);
  int smem_size = cu_ctc::topk::calc_smem_size_for_block_wide<float, int>(
      num_of_subwarp, beam);
-  BitonicTopkFuns[fun_idx]<<<grid, block, smem_size, stream>>>(
+  auto kernel = BitonicTopkFuns[fun_idx];
+  kernel<<<grid, block, smem_size, stream>>>(
      (*log_prob_struct),
      step,
      ptable,
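Both launch sites index into an array of kernel function pointers, one per template specialization. Hoisting the lookup into a local `kernel` variable keeps the expression in front of `<<<...>>>` a plain identifier; hipify rewrites triple-chevron launches into its `hipLaunchKernelGGL` macro, and a subscript expression in the kernel position can confuse that rewrite. A reduced sketch with hypothetical kernels and dispatch table (not the real FirstMatrixFuns/BitonicTopkFuns entries):

```cuda
#include <cuda_runtime.h>

__global__ void decode_k32(float*) {}
__global__ void decode_k64(float*) {}

using DecodeFn = void (*)(float*);
const DecodeFn kFuns[] = {decode_k32, decode_k64};  // hypothetical dispatch table

void launch(int fun_idx, dim3 grid, dim3 block, int smem_size,
            cudaStream_t stream, float* buf) {
  // Bind the table entry to a simple name first, so the launch reads
  // kernel<<<...>>> rather than kFuns[fun_idx]<<<...>>>.
  auto kernel = kFuns[fun_idx];
  kernel<<<grid, block, smem_size, stream>>>(buf);
}
```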
@@ -24,9 +24,26 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
+#include <cuda_runtime.h>
#include <iostream>
#include <vector>
-#include "include/ctc_prefix_decoder_host.h"
+#include "../include/ctc_prefix_decoder_host.h"
+#define CUDA_CHECK(X)                                   \
+  do {                                                  \
+    auto result = X;                                    \
+    if (result != cudaSuccess) {                        \
+      const char* p_err_str = cudaGetErrorName(result); \
+      fprintf(                                          \
+          stderr,                                       \
+          "File %s Line %d %s returned %s.\n",          \
+          __FILE__,                                     \
+          __LINE__,                                     \
+          #X,                                           \
+          p_err_str);                                   \
+      abort();                                          \
+    }                                                   \
+  } while (0)
namespace cu_ctc {
constexpr size_t ALIGN_BYTES = 128;
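The `CUDA_CHECK` macro moves into this header from ctc_prefix_decoder_host.h (where the second hunk above removes it); hipify rewrites the `cudaSuccess`/`cudaGetErrorName` calls inside it to their `hip*` forms, so the same macro serves both builds. A minimal usage sketch, assuming the macro above is in scope:

```cuda
#include <cstdio>   // fprintf, used by the macro
#include <cstdlib>  // abort, used by the macro
#include <cuda_runtime.h>

int main() {
  float* d_buf = nullptr;
  // On failure, CUDA_CHECK prints the file, line, and failing expression, then aborts.
  CUDA_CHECK(cudaMalloc(&d_buf, 1024 * sizeof(float)));
  CUDA_CHECK(cudaMemset(d_buf, 0, 1024 * sizeof(float)));
  CUDA_CHECK(cudaFree(d_buf));
  return 0;
}
```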