Commit 0a1801ed authored by Yuekai Zhang's avatar Yuekai Zhang Committed by Facebook GitHub Bot
Browse files

Add cuctc decoder (#3096)

Summary:
This PR implements a CUDA based ctc prefix beam search decoder.

Attach serveral benchmark results using V100 below:
|decoder type| model |datasets       | decoding time (secs)| beam size | batch size | model unit | subsampling times | vocab size |
|--------------|---------|------|-----------------|------------|-------------|------------|-----------------------|------------|
| cuctc |  conformer nemo    |dev clean        |7.68s | 8           |  32       | bpe         |    4  | 1000|
| cuctc |  conformer nemo   |dev clean  (sort by length)      |1.6s | 8           |  32       | bpe         |    4  | 1000|
| cuctc |  wav2vec2.0 torchaudio |dev clean                                |22s | 10           |  1       | char         |    2  | 29|
| cuctc |   conformer espnet   |aishell1 test                             | 5s | 10           |  24       | char         |    4  | 4233|

Note:
1.  The design is to parallel computation through batch and vocab axis, for loop the frames axis. So it's more friendly with smaller sequence lengths, larger vocab size comparing with CPU implementations.
2. WER is the same as CPU implementations. However, it can't decode with LM now.

Resolves: https://github.com/pytorch/audio/issues/2957.

Pull Request resolved: https://github.com/pytorch/audio/pull/3096

Reviewed By: nateanl

Differential Revision: D44709397

Pulled By: mthrok

fbshipit-source-id: 3078c54a2b44dc00eb4a81b4c657487eeff8c155
parent 151ac4d8
/**
* Modified from NVIDIA/cutlass(https://github.com/NVIDIA/cutlass)
*
*/
#pragma once
namespace cu_ctc {
template <typename value_t>
__host__ __device__ __forceinline__ value_t clz(value_t x) {
for (int i = 31; i >= 0; --i) {
if ((1 << i) & x)
return 31 - i;
}
return 32;
}
template <typename value_t>
__host__ __device__ __forceinline__ value_t find_log2(value_t x) {
int a = int(31 - clz(x));
a += (x & (x - 1)) != 0; // Round up, add 1 if not a power of 2.
return a;
}
/**
* Find divisor, using find_log2
*/
__host__ __device__ __forceinline__ void find_divisor(
unsigned int& mul,
unsigned int& shr,
unsigned int denom) {
if (denom == 1) {
mul = 0;
shr = 0;
} else {
unsigned int p = 31 + find_log2(denom);
unsigned m =
unsigned(((1ull << p) + unsigned(denom) - 1) / unsigned(denom));
mul = m;
shr = p - 32;
}
}
__host__ __device__ __forceinline__
void
fast_divmod(
int& quo,
int& rem,
int src,
int div,
unsigned int mul,
unsigned int shr) {
#if defined(__CUDA_ARCH__)
// Use IMUL.HI if div != 1, else simply copy the source.
quo = (div != 1) ? __umulhi(src, mul) >> shr : src;
#else
quo = int((div != 1) ? int(((int64_t)src * mul) >> 32) >> shr : src);
#endif
// The remainder.
rem = src - (quo * div);
}
// For long int input
__host__ __device__ __forceinline__ void fast_divmod(
int& quo,
int64_t& rem,
int64_t src,
int div,
unsigned int mul,
unsigned int shr) {
#if defined(__CUDA_ARCH__)
// Use IMUL.HI if div != 1, else simply copy the source.
quo = (div != 1) ? __umulhi(src, mul) >> shr : src;
#else
quo = int((div != 1) ? ((src * mul) >> 32) >> shr : src);
#endif
// The remainder.
rem = src - (quo * div);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Object to encapsulate the fast division+modulus operation.
///
/// This object precomputes two values used to accelerate the computation and is
/// best used when the divisor is a grid-invariant. In this case, it may be
/// computed in host code and marshalled along other kernel arguments using the
/// 'Params' pattern.
///
/// Example:
///
///
/// int quotient, remainder, dividend, divisor;
///
/// FastDivmod divmod(divisor);
///
/// divmod(quotient, remainder, dividend);
///
/// // quotient = (dividend / divisor)
/// // remainder = (dividend % divisor)
///
struct FastDivmod {
int divisor;
unsigned int multiplier;
unsigned int shift_right;
/// Construct the FastDivmod object, in host code ideally.
///
/// This precomputes some values based on the divisor and is computationally
/// expensive.
__host__ __device__ __forceinline__ FastDivmod()
: divisor(0), multiplier(0), shift_right(0) {}
__host__ __device__ __forceinline__ FastDivmod(int divisor_)
: divisor(divisor_) {
find_divisor(multiplier, shift_right, divisor);
}
/// Computes integer division and modulus using precomputed values. This is
/// computationally inexpensive.
__host__ __device__ __forceinline__ void operator()(
int& quotient,
int& remainder,
int dividend) const {
fast_divmod(
quotient, remainder, dividend, divisor, multiplier, shift_right);
}
/// Computes integer division and modulus using precomputed values. This is
/// computationally inexpensive.
///
/// Simply returns the quotient
__host__ __device__ __forceinline__ int divmod(int& remainder, int dividend)
const {
int quotient;
fast_divmod(
quotient, remainder, dividend, divisor, multiplier, shift_right);
return quotient;
}
/// Computes integer division and modulus using precomputed values. This is
/// computationally inexpensive.
__host__ __device__ __forceinline__ void operator()(
int& quotient,
int64_t& remainder,
int64_t dividend) const {
fast_divmod(
quotient, remainder, dividend, divisor, multiplier, shift_right);
}
/// Computes integer division and modulus using precomputed values. This is
/// computationally inexpensive.
__host__ __device__ __forceinline__ int divmod(
int64_t& remainder,
int64_t dividend) const {
int quotient;
fast_divmod(
quotient, remainder, dividend, divisor, multiplier, shift_right);
return quotient;
}
};
} // namespace cu_ctc
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cuda_runtime.h>
#include "include/ctc_prefix_decoder.h"
#include "include/ctc_prefix_decoder_host.h"
#include "device_data_wrap.h"
#include "device_log_prob.cuh"
namespace cu_ctc {
struct InternalData {
cudaStream_t stream;
int lc;
int ldc;
int bs;
int beam;
int ldbeam;
int time;
int ldseq_len;
DeviceDataWrap<float2> pprev;
DeviceDataWrap<float> ptable;
DeviceDataWrap<float> ptablen;
DeviceDataWrap<int> clast;
DeviceDataWrap<int> clen[2];
DeviceDataWrap<int> clist[2];
DeviceDataWrap<int> ptid;
DeviceDataWrap<float> score;
DeviceDataWrap<float> topk_key_buffer;
DeviceDataWrap<int> topk_value_buffer;
DeviceDataWrap<int> select_seqs;
DeviceDataWrap<int> select_seq_lens;
LogProb log_prob;
int max_select_seq_len;
};
std::tuple<size_t, int> calculate_require_buff_and_init_internal_data(
InternalData* inter_data,
int batch_size,
int seq_len,
int vocab_size,
int beam,
std::uintptr_t buff_ptr,
size_t buff_size,
float* log_prob_data_ptr,
int* original_lens,
const std::vector<int>& prob_sizes,
const std::vector<int>& prob_strides,
int blid,
float threshold) {
if ((batch_size * beam * seq_len * vocab_size) <= 0)
return {0, 0};
CHECK(prob_sizes.size() == 3, "only support 3D log_prob.");
CHECK(prob_strides.size() == 3, "only support 3D log_prob. ");
CHECK(
prob_sizes[0] == batch_size && prob_sizes[1] == seq_len &&
prob_sizes[2] == vocab_size,
"batch_size ,seq_len ,vocab_size must match with porb_size");
auto align_size = [](size_t size) -> size_t {
return (size + ALIGN_BYTES - 1) / ALIGN_BYTES * ALIGN_BYTES;
};
int lc = vocab_size;
int ldc = lc;
int ldbeam = ((beam - 1) / 16 + 1) * 16;
int ldseq_len = (seq_len + 16 - 1) / 16 * 16;
int bs = batch_size;
int time = seq_len;
size_t require_size = 0;
size_t pprev_size = sizeof(float2) * bs * ldbeam;
size_t pprev_align_size = align_size(pprev_size);
require_size += pprev_align_size;
size_t ptable_size = sizeof(float) * (bs * beam * ldc);
size_t ptablen_size = sizeof(float) * bs * beam * ldc;
size_t ptable_align_size = align_size(ptable_size);
size_t ptablen_align_size = align_size(ptablen_size);
require_size += ptable_align_size;
require_size += ptablen_align_size;
size_t clast_align_size = align_size(sizeof(int) * ldbeam * bs);
require_size += clast_align_size;
size_t clen_align_size = align_size(sizeof(int) * ldbeam * bs);
size_t clist_align_size = align_size(sizeof(int) * ldseq_len * beam * bs);
require_size += 2 * clen_align_size;
require_size += 2 * clist_align_size;
size_t ptid_align_size = align_size(sizeof(int) * bs * ldbeam);
require_size += ptid_align_size;
size_t score_align_size = align_size(sizeof(float) * bs * ldbeam);
require_size += score_align_size;
size_t key_buff_align_size = align_size(sizeof(float) * beam * MAX_BLOCKS);
size_t value_buff_align_size = align_size(sizeof(int) * beam * MAX_BLOCKS);
require_size += (key_buff_align_size + value_buff_align_size);
size_t select_seqs_align_size =
align_size(sizeof(int) * batch_size * seq_len);
require_size += select_seqs_align_size;
size_t select_seq_lens_align_size = align_size(sizeof(int) * batch_size);
require_size += select_seq_lens_align_size;
require_size += ALIGN_BYTES;
if (require_size > buff_size)
return {require_size, 0};
char* buff_align_ptr = reinterpret_cast<char*>(align_size(buff_ptr));
inter_data->beam = beam;
inter_data->ldbeam = ldbeam;
inter_data->bs = bs;
inter_data->lc = lc;
inter_data->ldc = ldc;
inter_data->time = time;
inter_data->ldseq_len = ldseq_len;
#define SET_DATA(NAME, TYPE, SIZE) \
inter_data->NAME = \
DeviceDataWrap<TYPE>(reinterpret_cast<TYPE*>(buff_align_ptr), SIZE); \
buff_align_ptr += SIZE;
SET_DATA(pprev, float2, pprev_align_size);
SET_DATA(ptable, float, ptable_align_size);
SET_DATA(ptablen, float, ptable_align_size);
SET_DATA(clast, int, clast_align_size);
SET_DATA(clen[0], int, clen_align_size);
SET_DATA(clen[1], int, clen_align_size);
SET_DATA(clist[0], int, clist_align_size);
SET_DATA(clist[1], int, clist_align_size);
SET_DATA(ptid, int, ptid_align_size);
SET_DATA(score, float, score_align_size);
SET_DATA(topk_key_buffer, float, key_buff_align_size);
SET_DATA(topk_value_buffer, int, value_buff_align_size);
SET_DATA(select_seqs, int, select_seqs_align_size);
SET_DATA(select_seq_lens, int, select_seq_lens_align_size);
#undef SET_DATA
// init log_prob
inter_data->log_prob.data_ptr = log_prob_data_ptr;
inter_data->log_prob.origin_seq_lens = original_lens;
inter_data->log_prob.select_seqs = inter_data->select_seqs.data_ptr();
inter_data->log_prob.select_seq_lens = inter_data->select_seq_lens.data_ptr();
inter_data->log_prob.batch = batch_size;
inter_data->log_prob.vocab_size = vocab_size;
inter_data->log_prob.seq_len = seq_len;
inter_data->log_prob.batch_stride = prob_strides[0];
inter_data->log_prob.seq_len_stride = prob_strides[1];
inter_data->log_prob.vocab_stride = prob_strides[2];
inter_data->max_select_seq_len = init_log_prob_and_cal_max_select_seq_len(
&(inter_data->log_prob), blid, threshold, inter_data->stream);
return {0, inter_data->max_select_seq_len};
}
int prefixCTC_V2(
InternalData* inter_data,
int blid,
int spid,
int step,
bool is_last_step,
int max_select_seq_len) {
LogProb* log_prob_struct = &(inter_data->log_prob);
if (step == 0) {
CTC_prob_first_step_V2(
log_prob_struct,
step,
inter_data->pprev,
inter_data->ptid,
inter_data->clast,
inter_data->clen[step % 2],
inter_data->clist[step % 2],
inter_data->beam,
inter_data->ldbeam,
inter_data->ldseq_len,
inter_data->bs,
inter_data->score,
inter_data->stream,
blid);
} else {
CTC_prob_matrix_V2(
log_prob_struct,
step,
inter_data->pprev,
inter_data->ptable,
inter_data->ptablen,
inter_data->clast,
inter_data->lc,
inter_data->ldc,
inter_data->beam,
inter_data->ldbeam,
inter_data->bs,
blid,
spid,
inter_data->stream);
CTC_prob_merge_V2(
log_prob_struct,
step,
inter_data->ptable,
inter_data->ptablen,
inter_data->ptid,
inter_data->clast,
inter_data->clist[(step % 2) ^ 1],
inter_data->clen[(step % 2) ^ 1],
inter_data->lc,
inter_data->ldc,
inter_data->beam,
inter_data->ldbeam,
inter_data->ldseq_len,
inter_data->bs,
inter_data->stream,
blid);
CTC_prob_topK_V2(
log_prob_struct,
step,
inter_data->pprev,
inter_data->ptable,
inter_data->ptablen,
inter_data->ptid,
inter_data->clast,
inter_data->clen[(step % 2) ^ 1],
inter_data->clen[(step % 2)],
inter_data->clist[(step % 2) ^ 1],
inter_data->clist[(step % 2)],
inter_data->lc,
inter_data->ldc,
inter_data->beam,
inter_data->ldbeam,
inter_data->ldseq_len,
blid,
inter_data->bs,
inter_data->score,
inter_data->topk_key_buffer,
inter_data->topk_value_buffer,
inter_data->stream,
is_last_step);
if (is_last_step) {
// if the parity of select_seq_len is different from the
// max_select_seq_len, their clist and clen need to be copy to another
// clist and clen
CTC_copy_list_len_for_differnet_parity(
log_prob_struct,
step,
max_select_seq_len,
inter_data->clen[(step % 2) ^ 1],
inter_data->clen[(step % 2)],
inter_data->clist[(step % 2) ^ 1],
inter_data->clist[(step % 2)],
inter_data->bs,
inter_data->beam,
inter_data->ldbeam,
inter_data->ldseq_len,
inter_data->stream);
}
}
return 0;
}
std::uintptr_t prefixCTC_alloc(std::uintptr_t stream_ptr) {
InternalData* Inter_data = new InternalData;
Inter_data->stream = reinterpret_cast<cudaStream_t>(stream_ptr);
return reinterpret_cast<std::uintptr_t>(Inter_data);
}
void prefixCTC_free(std::uintptr_t inter_data_ptr) {
InternalData* inter_data = reinterpret_cast<InternalData*>(inter_data_ptr);
delete inter_data;
}
int ctc_beam_search_decoder_batch_gpu(
InternalData* inter_data,
float* pp,
int blid,
int spid,
int* clist,
int* clen,
float* score) {
// batch_pprev: time x batch x lc
// internal_data *data = (internal_data *)data_int;
CUDA_CHECK(cudaMemsetAsync(
(inter_data->clast.data_ptr()),
0,
inter_data->clast.size_in_byte(),
inter_data->stream));
CUDA_CHECK(cudaMemsetAsync(
(inter_data->clen[0].data_ptr()),
0,
inter_data->clen[0].size_in_byte(),
inter_data->stream));
CUDA_CHECK(cudaMemsetAsync(
(inter_data->clen[1].data_ptr()),
0,
inter_data->clen[0].size_in_byte(),
inter_data->stream));
CUDA_CHECK(cudaMemsetAsync(
(inter_data->clist[0].data_ptr()),
-1,
inter_data->clen[0].size_in_byte(),
inter_data->stream));
CUDA_CHECK(cudaMemsetAsync(
(inter_data->clist[1].data_ptr()),
-1,
inter_data->clen[0].size_in_byte(),
inter_data->stream));
// ptable the table of prob for end_in_bank (bs*beam*vocab_size)
// ptablen the table of prob for no_end_in_bank(ba*beam*vocab_size)
int step = 0;
while (step < inter_data->max_select_seq_len) {
bool is_last_step = (step == (inter_data->max_select_seq_len - 1));
prefixCTC_V2(
inter_data,
blid,
spid,
step,
is_last_step,
inter_data->max_select_seq_len);
step++;
}
CUDA_CHECK(cudaMemcpy2DAsync(
clen,
sizeof(int) * inter_data->beam,
inter_data->clen[(step % 2) ^ 1].data_ptr(),
sizeof(int) * inter_data->ldbeam,
sizeof(int) * inter_data->beam,
inter_data->bs,
cudaMemcpyDeviceToHost,
inter_data->stream));
CUDA_CHECK(cudaMemcpy2DAsync(
clist,
sizeof(int) * inter_data->max_select_seq_len,
inter_data->clist[(step % 2) ^ 1].data_ptr(),
sizeof(int) * inter_data->ldseq_len,
sizeof(int) * inter_data->max_select_seq_len,
inter_data->beam * inter_data->bs,
cudaMemcpyDeviceToHost,
inter_data->stream));
CUDA_CHECK(cudaMemcpy2DAsync(
score,
sizeof(float) * inter_data->beam,
inter_data->score.data_ptr(),
sizeof(float) * inter_data->ldbeam,
sizeof(float) * inter_data->beam,
inter_data->bs,
cudaMemcpyDeviceToHost,
inter_data->stream));
CUDA_CHECK(cudaStreamSynchronize(inter_data->stream));
return 0;
}
} // namespace cu_ctc
This diff is collapsed.
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <iostream>
#include <vector>
#include "include/ctc_prefix_decoder_host.h"
namespace cu_ctc {
constexpr size_t ALIGN_BYTES = 128;
constexpr int MAX_BLOCKS = 800;
template <typename T>
class DeviceDataWrap {
public:
DeviceDataWrap() : data_{}, size_in_bytes_{} {};
DeviceDataWrap(T* data_ptr, size_t size_in_byte)
: data_{data_ptr}, size_in_bytes_{size_in_byte} {};
void print(size_t offset, size_t size_in_element, int eles_per_row = 10)
const {
if ((offset + size_in_element) * sizeof(T) > size_in_bytes_) {
std::cerr
<< " ERROR: in DeviceDataWrap print : offset+size_in_element > size_in_bytes_";
abort();
}
std::vector<T> host_data(size_in_element);
CUDA_CHECK(cudaMemcpy(
host_data.data(),
data_ + offset,
size_in_element * sizeof(T),
cudaMemcpyDeviceToHost));
for (int i = 0; i < size_in_element; ++i) {
if (i != 0 && (i % eles_per_row == 0)) {
std::cout << " \n";
}
std::cout << "[" << i << "]:" << host_data[i] << " ";
}
std::cout << "\n";
}
operator T*() {
return data_;
}
operator const T*() {
return const_cast<const T*>(data_);
}
T* data_ptr() const {
return data_;
}
size_t size_in_byte() const {
return size_in_bytes_;
}
void set_data_ptr(T* data_ptr) {
data_ = data_ptr;
}
void set_size_in_byte(size_t size_in_byte) {
size_in_bytes_ = size_in_byte;
}
private:
T* data_;
size_t size_in_bytes_;
};
} // namespace cu_ctc
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
namespace cu_ctc {
struct LogProb {
float* data_ptr;
int batch;
int seq_len;
int vocab_size;
int batch_stride;
int seq_len_stride;
int vocab_stride;
int* origin_seq_lens; // batchs
int* select_seqs; // batchs *seq_len;
int* select_seq_lens; // batchs
__device__ __forceinline__ float at(int batch_id, int seq_id, int char_id) {
return data_ptr
[batch_id * batch_stride + seq_id * seq_len_stride +
char_id * vocab_stride];
}
__device__ __forceinline__ int ith_selected_seq_in_this_batch(
int batch_id,
int i) {
return select_seqs[batch_id * seq_len + i];
}
__device__ __forceinline__ bool need_process_on_ith_step(
int batch_id,
int istep) {
return istep < select_seq_lens[batch_id];
}
/**
* @brief if the prob of blank in next original timestep > threshold , we
* will not process the next original timestep, but will process the
* subsequent blank on the currently processed timestep.
*
* @param batch_id
* @param istep
* @return __device__
*/
__device__ __forceinline__ bool need_add_blank(int batch_id, int istep) {
if ((istep < 0) || (istep + 1) >= select_seq_lens[batch_id]) {
return false;
}
if ((ith_selected_seq_in_this_batch(batch_id, istep + 1) -
ith_selected_seq_in_this_batch(batch_id, istep)) > 1) {
return true;
}
return false;
}
};
} // namespace cu_ctc
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <tuple>
#include <utility>
#include <vector>
#include "include/ctc_prefix_decoder.h"
namespace py = pybind11;
std::tuple<size_t, std::vector<std::vector<std::pair<float, std::vector<int>>>>>
ctc_prefix_decoder_batch_wrapper(
std::uintptr_t n_inter_data,
std::uintptr_t buff_ptr,
size_t buff_size,
std::uintptr_t pp,
std::uintptr_t seq_len_ptr,
const std::vector<int>& pp_sizes,
const std::vector<int>& pp_strides,
int beam,
int blid,
int spid,
float thresold) {
using SCORE_TYPE =
std::vector<std::vector<std::pair<float, std::vector<int>>>>;
cu_ctc::InternalData* inter_data = (cu_ctc::InternalData*)(n_inter_data);
auto [require_size, max_select_seq_len] =
cu_ctc::calculate_require_buff_and_init_internal_data(
inter_data,
pp_sizes[0],
pp_sizes[1],
pp_sizes[2],
beam,
buff_ptr,
buff_size,
(float*)pp,
(int*)seq_len_ptr,
pp_sizes,
pp_strides,
blid,
thresold);
if (require_size > 0) {
return std::make_tuple(require_size, SCORE_TYPE{});
}
int batch_size = pp_sizes[0];
std::vector<int> list_data(batch_size * beam * max_select_seq_len);
std::vector<int> len_data(batch_size * beam);
std::vector<float> score(batch_size * beam);
cu_ctc::ctc_beam_search_decoder_batch_gpu(
inter_data,
(float*)pp,
blid,
spid,
list_data.data(),
len_data.data(),
score.data());
SCORE_TYPE score_hyps{};
score_hyps.reserve(batch_size);
for (int b = 0; b < batch_size; b++) {
score_hyps.push_back(std::vector<std::pair<float, std::vector<int>>>{});
score_hyps.back().reserve(beam);
for (int beam_id = 0; beam_id < beam; beam_id++) {
int len = len_data[b * beam + beam_id];
int offset = b * beam * max_select_seq_len + beam_id * max_select_seq_len;
std::vector<int> clist(
list_data.data() + offset, list_data.data() + offset + len);
score_hyps.back().push_back(
std::pair{score[b * beam + beam_id], std::move(clist)});
}
}
return std::make_tuple(require_size, std::move(score_hyps));
}
PYBIND11_MODULE(pybind11_prefixctc, m) {
m.doc() = "none";
m.def(
"ctc_beam_search_decoder_batch_gpu_v2",
&ctc_prefix_decoder_batch_wrapper,
"ctc prefix decoder v2 computing on GPU");
m.def("prefixCTC_alloc", &cu_ctc::prefixCTC_alloc, "allocate internal data");
m.def("prefixCTC_free", &cu_ctc::prefixCTC_free, "free internal data");
}
...@@ -6,6 +6,11 @@ _CTC_DECODERS = [ ...@@ -6,6 +6,11 @@ _CTC_DECODERS = [
"ctc_decoder", "ctc_decoder",
"download_pretrained_files", "download_pretrained_files",
] ]
_CUDA_CTC_DECODERS = [
"CUCTCDecoder",
"CUCTCHypothesis",
"cuda_ctc_decoder",
]
def __getattr__(name: str): def __getattr__(name: str):
...@@ -20,6 +25,17 @@ def __getattr__(name: str): ...@@ -20,6 +25,17 @@ def __getattr__(name: str):
item = getattr(_ctc_decoder, name) item = getattr(_ctc_decoder, name)
globals()[name] = item globals()[name] = item
return item return item
elif name in _CUDA_CTC_DECODERS:
try:
from . import _cuda_ctc_decoder
except AttributeError as err:
raise RuntimeError(
"To use CUCTC decoder, please set BUILD_CUDA_CTC_DECODER=1 when building from source."
) from err
item = getattr(_cuda_ctc_decoder, name)
globals()[name] = item
return item
raise AttributeError(f"module {__name__} has no attribute {name}") raise AttributeError(f"module {__name__} has no attribute {name}")
...@@ -27,4 +43,4 @@ def __dir__(): ...@@ -27,4 +43,4 @@ def __dir__():
return sorted(__all__) return sorted(__all__)
__all__ = _CTC_DECODERS __all__ = [_CTC_DECODERS, _CUDA_CTC_DECODERS]
from __future__ import annotations
import math
from typing import List, NamedTuple, Union
import torch
import torchaudio
torchaudio._extension._load_lib("libctc_prefix_decoder")
import torchaudio.lib.pybind11_prefixctc as cuctc
__all__ = ["CUCTCHypothesis", "CUCTCDecoder", "cuda_ctc_decoder"]
def _get_vocab_list(vocab_file):
vocab = []
with open(vocab_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip().split()
vocab.append(line[0])
return vocab
class CUCTCHypothesis(NamedTuple):
r"""Represents hypothesis generated by CUCTC beam search decoder :class:`CUCTCDecoder`."""
tokens: List[int]
"""Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""
words: List[str]
"""List of predicted tokens. Algin with modeling unit.
"""
score: float
"""Score corresponding to hypothesis"""
_DEFAULT_SKIP_THREASHOLD = math.log(0.95)
class CUCTCDecoder:
"""CUDA CTC beam search decoder.
.. devices:: CUDA
Note:
To build the decoder, please use the factory function :func:`cuda_ctc_decoder`.
"""
def __init__(
self,
vocab_list: List[str],
blank_id: int = 0,
beam_size: int = 10,
nbest: int = 1,
blank_skip_threshold: float = _DEFAULT_SKIP_THREASHOLD,
cuda_stream: torch.cuda.streams.Stream = None,
):
"""
Args:
blank_id (int): token id corresopnding to blank (Default: 0)
vocab_list (List[str]): list of vocabulary tokens
beam_size (int, optional): max number of hypos to hold after each decode step (Default: 10)
nbest (int): number of best decodings to return
blank_skip_threshold (float): skip frames if log_prob(blank) > blank_skip_threshold, to speed up decoding.
(Default: log(0.95)).
cuda_stream (torch.cuda.streams.Stream): using assigned cuda stream (Default: using default stream)
"""
if cuda_stream:
if not isinstance(cuda_stream, torch.cuda.streams.Stream):
raise AssertionError("cuda_stream must be torch.cuda.streams.Stream")
cuda_stream_ = cuda_stream.cuda_stream if cuda_stream else torch.cuda.current_stream().cuda_stream
self.internal_data = cuctc.prefixCTC_alloc(cuda_stream_)
self.memory = torch.empty(0, dtype=torch.int8, device=torch.device("cuda"))
self.blank_id = 0 # blank id has to be zero
self.vocab_list = vocab_list
self.space_id = 0
self.nbest = nbest
self.blank_skip_threshold = blank_skip_threshold
self.beam_size = min(beam_size, len(vocab_list)) # beam size must be smaller than vocab size
def __del__(self):
if cuctc is not None:
cuctc.prefixCTC_free(self.internal_data)
def __call__(self, log_prob: torch.Tensor, encoder_out_lens: torch.Tensor):
"""
Args:
log_prob (torch.FloatTensor): GPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
probability distribution over labels; log_softmax(output of acoustic model).
lengths (dtype torch.int32): GPU tensor of shape `(batch, )` storing the valid length of
in time axis of the output Tensor in each batch.
Returns:
List[List[CUCTCHypothesis]]:
List of sorted best hypotheses for each audio sequence in the batch.
"""
if not encoder_out_lens.dtype == torch.int32:
raise AssertionError("encoder_out_lens must be torch.int32")
if not log_prob.dtype == torch.float32:
raise AssertionError("log_prob must be torch.float32")
if not (log_prob.is_cuda and encoder_out_lens.is_cuda):
raise AssertionError("inputs must be cuda tensors")
if not (log_prob.is_contiguous() and encoder_out_lens.is_contiguous()):
raise AssertionError("input tensors must be contiguous")
required_size, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
self.internal_data,
self.memory.data_ptr(),
self.memory.size(0),
log_prob.data_ptr(),
encoder_out_lens.data_ptr(),
log_prob.size(),
log_prob.stride(),
self.beam_size,
self.blank_id,
self.space_id,
self.blank_skip_threshold,
)
if required_size > 0:
self.memory = torch.empty(required_size, dtype=torch.int8, device=log_prob.device).contiguous()
_, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
self.internal_data,
self.memory.data_ptr(),
self.memory.size(0),
log_prob.data_ptr(),
encoder_out_lens.data_ptr(),
log_prob.size(),
log_prob.stride(),
self.beam_size,
self.blank_id,
self.space_id,
self.blank_skip_threshold,
)
batch_size = len(score_hyps)
hypos = []
for i in range(batch_size):
hypos.append(
[
CUCTCHypothesis(
tokens=score_hyps[i][j][1],
words=[self.vocab_list[word_id] for word_id in score_hyps[i][j][1]],
score=score_hyps[i][j][0],
)
for j in range(self.nbest)
]
)
return hypos
def cuda_ctc_decoder(
tokens: Union[str, List[str]],
nbest: int = 1,
beam_size: int = 10,
blank_skip_threshold: float = _DEFAULT_SKIP_THREASHOLD,
) -> CUCTCDecoder:
"""Builds an instance of :class:`CUCTCDecoder`.
Args:
tokens (str or List[str]): File or list containing valid tokens.
If using a file, the expected format is for tokens mapping to the same index to be on the same line
beam_size (int, optional): The maximum number of hypos to hold after each decode step (Default: 10)
nbest (int): The number of best decodings to return
blank_id (int): The token ID corresopnding to the blank symbol.
blank_skip_threshold (float): skip frames if log_prob(blank) > blank_skip_threshold, to speed up decoding
(Default: log(0.95)).
Returns:
CUCTCDecoder: decoder
Example
>>> decoder = cuda_ctc_decoder(
>>> vocab_file="tokens.txt",
>>> blank_skip_threshold=math.log(0.95),
>>> )
>>> results = decoder(log_probs, encoder_out_lens) # List of shape (B, nbest) of Hypotheses
"""
if type(tokens) == str:
tokens = _get_vocab_list(tokens)
return CUCTCDecoder(vocab_list=tokens, beam_size=beam_size, nbest=nbest, blank_skip_threshold=blank_skip_threshold)
...@@ -4,7 +4,6 @@ It affects functionalities in :py:mod:`torchaudio.io` (and indirectly :py:func:` ...@@ -4,7 +4,6 @@ It affects functionalities in :py:mod:`torchaudio.io` (and indirectly :py:func:`
""" """
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import torch
import torchaudio import torchaudio
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment