Unverified commit 79047bf6, authored by Xiaodong Wang, committed by GitHub
Browse files

[AMD] Hipify torchaudio_decoder

Differential Revision: D64298970

Pull Request resolved: https://github.com/pytorch/audio/pull/3843
parent b4a286a1
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#ifndef __ctc_prefix_decoder_h_ #ifndef __ctc_prefix_decoder_h_
#define __ctc_prefix_decoder_h_ #define __ctc_prefix_decoder_h_
#include <cuda_runtime.h>
#include <cstdint> #include <cstdint>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
......
...@@ -26,24 +26,6 @@ ...@@ -26,24 +26,6 @@
#ifndef __ctc_prefix_decoder_host_h_ #ifndef __ctc_prefix_decoder_host_h_
#define __ctc_prefix_decoder_host_h_ #define __ctc_prefix_decoder_host_h_
#include <cuda_runtime.h>
#define CUDA_CHECK(X) \
do { \
auto result = X; \
if (result != cudaSuccess) { \
const char* p_err_str = cudaGetErrorName(result); \
fprintf( \
stderr, \
"File %s Line %d %s returned %s.\n", \
__FILE__, \
__LINE__, \
#X, \
p_err_str); \
abort(); \
} \
} while (0)
#define CHECK(X, ERROR_INFO) \ #define CHECK(X, ERROR_INFO) \
do { \ do { \
auto result = (X); \ auto result = (X); \
......
...@@ -16,9 +16,13 @@ constexpr inline __host__ __device__ bool isPo2(IntType num) { ...@@ -16,9 +16,13 @@ constexpr inline __host__ __device__ bool isPo2(IntType num) {
} }
inline __device__ int laneId() { inline __device__ int laneId() {
#ifndef USE_ROCM
int id; int id;
asm("mov.s32 %0, %%laneid;" : "=r"(id)); asm("mov.s32 %0, %%laneid;" : "=r"(id));
return id; return id;
#else
return __lane_id();
#endif
} }
/** /**
* @brief Shuffle the data inside a warp * @brief Shuffle the data inside a warp
......
...@@ -12,7 +12,8 @@ namespace cu_ctc { ...@@ -12,7 +12,8 @@ namespace cu_ctc {
* @tparam IntType data type (checked only for integers) * @tparam IntType data type (checked only for integers)
*/ */
template <typename IntType> template <typename IntType>
constexpr __device__ IntType log2(IntType num, IntType ret = IntType(0)) { constexpr __host__ __device__ IntType
log2(IntType num, IntType ret = IntType(0)) {
return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret); return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret);
} }
......
...@@ -313,7 +313,7 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> { ...@@ -313,7 +313,7 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> {
__device__ __forceinline__ void merge_buf_() { __device__ __forceinline__ void merge_buf_() {
topk::bitonic<kMaxBufLen>(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); topk::bitonic<kMaxBufLen>(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
this->merge_in<kMaxBufLen>(val_buf_, idx_buf_); this->template merge_in<kMaxBufLen>(val_buf_, idx_buf_);
buf_len_ = 0; buf_len_ = 0;
set_k_th_(); // contains warp sync set_k_th_(); // contains warp sync
#pragma unroll #pragma unroll
...@@ -385,7 +385,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> { ...@@ -385,7 +385,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
if (buf_len_ == kMaxArrLen) { if (buf_len_ == kMaxArrLen) {
topk::bitonic<kMaxArrLen>(!Ascending, kWarpWidth) topk::bitonic<kMaxArrLen>(!Ascending, kWarpWidth)
.sort(val_buf_, idx_buf_); .sort(val_buf_, idx_buf_);
this->merge_in<kMaxArrLen>(val_buf_, idx_buf_); this->template merge_in<kMaxArrLen>(val_buf_, idx_buf_);
#pragma unroll #pragma unroll
for (int i = 0; i < kMaxArrLen; i++) { for (int i = 0; i < kMaxArrLen; i++) {
val_buf_[i] = kDummy; val_buf_[i] = kDummy;
...@@ -398,7 +398,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> { ...@@ -398,7 +398,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
if (buf_len_ != 0) { if (buf_len_ != 0) {
topk::bitonic<kMaxArrLen>(!Ascending, kWarpWidth) topk::bitonic<kMaxArrLen>(!Ascending, kWarpWidth)
.sort(val_buf_, idx_buf_); .sort(val_buf_, idx_buf_);
this->merge_in<kMaxArrLen>(val_buf_, idx_buf_); this->template merge_in<kMaxArrLen>(val_buf_, idx_buf_);
} }
} }
...@@ -421,7 +421,7 @@ constexpr inline __host__ __device__ IntType ceildiv(IntType a, IntType b) { ...@@ -421,7 +421,7 @@ constexpr inline __host__ __device__ IntType ceildiv(IntType a, IntType b) {
return (a + b - 1) / b; return (a + b - 1) / b;
} }
template <typename IntType> template <typename IntType>
constexpr inline __device__ IntType roundUp256(IntType num) { constexpr inline __host__ __device__ IntType roundUp256(IntType num) {
// return (num + 255) / 256 * 256; // return (num + 255) / 256 * 256;
constexpr int MASK = 255; constexpr int MASK = 255;
return (num + MASK) & (~MASK); return (num + MASK) & (~MASK);
......
...@@ -25,8 +25,8 @@ ...@@ -25,8 +25,8 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "include/ctc_prefix_decoder.h" #include "../include/ctc_prefix_decoder.h"
#include "include/ctc_prefix_decoder_host.h" #include "../include/ctc_prefix_decoder_host.h"
#include "device_data_wrap.h" #include "device_data_wrap.h"
#include "device_log_prob.cuh" #include "device_log_prob.cuh"
......
...@@ -23,12 +23,13 @@ ...@@ -23,12 +23,13 @@
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <float.h>
#include <algorithm> #include <algorithm>
#include "../include/ctc_prefix_decoder_host.h"
#include "ctc_fast_divmod.cuh" #include "ctc_fast_divmod.cuh"
#include "cub/cub.cuh" #include "cub/cub.cuh"
#include "device_data_wrap.h" #include "device_data_wrap.h"
#include "device_log_prob.cuh" #include "device_log_prob.cuh"
#include "include/ctc_prefix_decoder_host.h"
#include "bitonic_topk/warpsort_topk.cuh" #include "bitonic_topk/warpsort_topk.cuh"
...@@ -630,7 +631,8 @@ int CTC_prob_first_step_V2( ...@@ -630,7 +631,8 @@ int CTC_prob_first_step_V2(
num_of_subwarp, beam)); num_of_subwarp, beam));
int smem_size = int smem_size =
block_sort_smem_size + beam * sizeof(float) + beam * sizeof(int); block_sort_smem_size + beam * sizeof(float) + beam * sizeof(int);
FirstMatrixFuns[fun_idx]<<<grid, threads_per_block, smem_size, stream>>>( auto kernel = FirstMatrixFuns[fun_idx];
kernel<<<grid, threads_per_block, smem_size, stream>>>(
(*log_prob_struct), (*log_prob_struct),
step, step,
pprev, pprev,
...@@ -766,7 +768,8 @@ int CTC_prob_topK_V2( ...@@ -766,7 +768,8 @@ int CTC_prob_topK_V2(
int num_of_subwarp = threads_per_block0 / std::min<int>(32, actual_capacity); int num_of_subwarp = threads_per_block0 / std::min<int>(32, actual_capacity);
int smem_size = cu_ctc::topk::calc_smem_size_for_block_wide<float, int>( int smem_size = cu_ctc::topk::calc_smem_size_for_block_wide<float, int>(
num_of_subwarp, beam); num_of_subwarp, beam);
BitonicTopkFuns[fun_idx]<<<grid, block, smem_size, stream>>>( auto kernel = BitonicTopkFuns[fun_idx];
kernel<<<grid, block, smem_size, stream>>>(
(*log_prob_struct), (*log_prob_struct),
step, step,
ptable, ptable,
......
...@@ -24,9 +24,26 @@ ...@@ -24,9 +24,26 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once #pragma once
#include <cuda_runtime.h>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include "include/ctc_prefix_decoder_host.h" #include "../include/ctc_prefix_decoder_host.h"
#define CUDA_CHECK(X) \
do { \
auto result = X; \
if (result != cudaSuccess) { \
const char* p_err_str = cudaGetErrorName(result); \
fprintf( \
stderr, \
"File %s Line %d %s returned %s.\n", \
__FILE__, \
__LINE__, \
#X, \
p_err_str); \
abort(); \
} \
} while (0)
namespace cu_ctc { namespace cu_ctc {
constexpr size_t ALIGN_BYTES = 128; constexpr size_t ALIGN_BYTES = 128;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment