#pragma once

#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>

// Project-local dependencies: Options, status_t / SUCCESS and the math::lse /
// math::min / math::max helpers are provided by the surrounding rnnt sources
// (exact include paths assumed from the torchaudio layout); TORCH_CHECK_EQ is
// expected to come from the PyTorch headers pulled in through them.
#include <torchaudio/csrc/rnnt/macros.h>
#include <torchaudio/csrc/rnnt/options.h>
#include <torchaudio/csrc/rnnt/types.h>

namespace torchaudio {
namespace rnnt {
namespace cpu {

// Blank/target log-probabilities for one (t, u) cell of the lattice.
template <typename DTYPE>
struct LogProbs {
  DTYPE skip_; // blank.
  DTYPE emit_; // target.

  LogProbs(DTYPE skip, DTYPE emit) : skip_(skip), emit_(emit) {}

  DTYPE& skip() {
    return skip_;
  }
  DTYPE& emit() {
    return emit_;
  }

  const DTYPE& skip() const {
    return skip_;
  }
  const DTYPE& emit() const {
    return emit_;
  }
};

// TensorView: view a block of allocated memory as a tensor.
template <typename DTYPE>
class TensorView {
 public:
  TensorView(const std::vector<int>& dims, DTYPE* data)
      : dims_(dims), data_(data) {
    strides_.resize(dims.size());
    strides_.back() = 1;
    for (int i = dims.size() - 2; i >= 0; --i) {
      strides_[i] = strides_[i + 1] * dims[i + 1];
    }
  }

  DTYPE& operator()(const std::vector<int>& indices) {
    TORCH_CHECK_EQ(indices.size(), dims_.size());
    int index = indices.back();
    for (int i = indices.size() - 2; i >= 0; --i) {
      index += indices[i] * strides_[i];
    }
    return data_[index];
  }

  void SetZero() {
    int size = dims_[0] * strides_[0];
    std::memset(data_, 0, sizeof(DTYPE) * size);
  }

 private:
  std::vector<int> dims_;
  std::vector<int> strides_;
  DTYPE* data_;
};

// Row-wise log-sum-exp over an (N x D) block of logits; used to compute the
// log-softmax denominators.
template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(int N, int D, const DTYPE* logits, CAST_DTYPE* outputs) {
  for (int i = 0; i < N * D; i += D) {
    CAST_DTYPE max = logits[i];
    for (int j = 1; j < D; ++j) {
      max = std::max(max, CAST_DTYPE(logits[i + j]));
    }
    CAST_DTYPE sum = 0;
    for (int j = 0; j < D; ++j) {
      sum = sum + std::exp(CAST_DTYPE(logits[i + j]) - max);
    }
    outputs[i / D] = max + std::log(sum);
  }
  return SUCCESS;
}

template <typename DTYPE, typename CAST_DTYPE>
void ComputeLogProbsOneSequence(
    const Options& options,
    TensorView<const DTYPE>& logits,
    const int* targets,
    int srcLen,
    int tgtLen,
    TensorView<const CAST_DTYPE>& denom,
    TensorView<LogProbs<CAST_DTYPE>>& logProbs) {
  const int& T = srcLen;
  const int& U = tgtLen;
  const int& blank = options.blank_;
  const bool& fusedLogSmax = options.fusedLogSmax_;

  for (int t = 0; t < T; ++t) {
    for (int u = 0; u < U; ++u) {
      if (u < U - 1) {
        logProbs({t, u}).emit() =
            CAST_DTYPE(logits({t, u, targets[u]})) - denom({t, u});
      }
      logProbs({t, u}).skip() =
          CAST_DTYPE(logits({t, u, blank})) - denom({t, u});

      if (!fusedLogSmax) {
        // Inputs are already log-probabilities; do not subtract the
        // denominator.
        if (u < U - 1) {
          logProbs({t, u}).emit() = CAST_DTYPE(logits({t, u, targets[u]}));
        }
        logProbs({t, u}).skip() = CAST_DTYPE(logits({t, u, blank}));
      }
    }
  }
}

template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeLogProbs(
    const Options& options,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    CAST_DTYPE* logProbs) {
  std::vector<TensorView<const DTYPE>> seqLogits;
  std::vector<const int*> seqTargets;
  std::vector<TensorView<const CAST_DTYPE>> seqDenoms;
  std::vector<TensorView<LogProbs<CAST_DTYPE>>> seqlogProbs;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;
  const int& D = options.numTargets_;

  for (int b = 0; b < B; ++b) {
    seqLogits.push_back(TensorView<const DTYPE>(
        {maxT, maxU, D}, logits + b * maxT * maxU * D));
    seqTargets.push_back(targets + b * (maxU - 1));
    seqDenoms.push_back(TensorView<const CAST_DTYPE>(
        {maxT, maxU}, denominators + b * maxT * maxU));
    seqlogProbs.push_back(TensorView<LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(logProbs) + b * maxT * maxU));
  }

  //#pragma omp parallel for
  for (int b = 0; b < B; ++b) { // one thread per sequence.
    ComputeLogProbsOneSequence<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*logits=*/seqLogits[b],
        /*targets=*/seqTargets[b],
        /*srcLen=*/srcLengths[b],
        /*tgtLen=*/tgtLengths[b] + 1, // with prepended blank.
        /*denom=*/seqDenoms[b],
        /*logProbs=*/seqlogProbs[b]);
  }

  return SUCCESS;
}
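// The two recursions below walk the (T x U) lattice of LogProbs cells. Writing
// e(t, u) and s(t, u) for the emit (target) and skip (blank) log-probabilities,
// they compute (a summary of the code that follows, not part of the original
// header):
//
//   alpha(0, 0) = 0
//   alpha(t, u) = logsumexp(alpha(t - 1, u) + s(t - 1, u),
//                           alpha(t, u - 1) + e(t, u - 1))
//
//   beta(T - 1, U - 1) = s(T - 1, U - 1)
//   beta(t, u)         = logsumexp(beta(t + 1, u) + s(t, u),
//                                  beta(t, u + 1) + e(t, u))
//
// so that alpha(T - 1, U - 1) + s(T - 1, U - 1) = beta(0, 0) = log P(y | x).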
// Forward pass over the lattice for a single sequence; returns the total
// forward score log P(y | x).
template <typename DTYPE>
DTYPE ComputeAlphaOneSequence(
    TensorView<LogProbs<DTYPE>>& logProbs,
    int srcLen,
    int tgtLen,
    TensorView<DTYPE>& alpha) {
  const int& T = srcLen;
  const int& U = tgtLen;

  alpha({0, 0}) = DTYPE(0);

  for (int t = 1; t < T; ++t) { // u == 0.
    alpha({t, 0}) = alpha({t - 1, 0}) + logProbs({t - 1, 0}).skip();
  }

  for (int u = 1; u < U; ++u) { // t == 0.
    alpha({0, u}) = alpha({0, u - 1}) + logProbs({0, u - 1}).emit();
  }

  for (int t = 1; t < T; ++t) {
    for (int u = 1; u < U; ++u) {
      alpha({t, u}) = math::lse(
          alpha({t - 1, u}) + logProbs({t - 1, u}).skip(),
          alpha({t, u - 1}) + logProbs({t, u - 1}).emit());
    }
  }

  DTYPE forward_score =
      alpha({T - 1, U - 1}) + logProbs({T - 1, U - 1}).skip();
  return forward_score;
}

// Backward pass over the lattice for a single sequence; returns the total
// backward score, which equals the forward score.
template <typename DTYPE>
DTYPE ComputeBetaOneSequence(
    TensorView<LogProbs<DTYPE>>& logProbs,
    int srcLen,
    int tgtLen,
    TensorView<DTYPE>& beta) {
  const int& T = srcLen;
  const int& U = tgtLen;

  beta({T - 1, U - 1}) = logProbs({T - 1, U - 1}).skip();

  for (int t = T - 2; t >= 0; --t) { // u == U - 1.
    beta({t, U - 1}) = beta({t + 1, U - 1}) + logProbs({t, U - 1}).skip();
  }

  for (int u = U - 2; u >= 0; --u) { // t == T - 1.
    beta({T - 1, u}) = beta({T - 1, u + 1}) + logProbs({T - 1, u}).emit();
  }

  for (int t = T - 2; t >= 0; --t) {
    for (int u = U - 2; u >= 0; --u) {
      beta({t, u}) = math::lse(
          beta({t + 1, u}) + logProbs({t, u}).skip(),
          beta({t, u + 1}) + logProbs({t, u}).emit());
    }
  }

  DTYPE backward_score = beta({0, 0});
  return backward_score;
}

// Dispatch helper: odd thread indices compute alpha, even ones compute beta,
// so the two passes for one sequence can run on two threads in parallel.
template <typename DTYPE>
DTYPE ComputeAlphaOrBetaOneSequence(
    int thread,
    const Options& options,
    TensorView<LogProbs<DTYPE>>& logProbs,
    int srcLen,
    int tgtLen,
    TensorView<DTYPE>& alpha,
    TensorView<DTYPE>& beta) {
  if (thread & 1) {
    return ComputeAlphaOneSequence<DTYPE>(
        /*logProbs=*/logProbs,
        /*srcLen=*/srcLen,
        /*tgtLen=*/tgtLen,
        /*alpha=*/alpha);
  } else {
    return ComputeBetaOneSequence<DTYPE>(
        /*logProbs=*/logProbs,
        /*srcLen=*/srcLen,
        /*tgtLen=*/tgtLen,
        /*beta=*/beta);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
void ComputeAlphasBetas(
    const Options& options,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    CAST_DTYPE* alphas,
    CAST_DTYPE* betas,
    DTYPE* costs) {
  std::vector<TensorView<LogProbs<CAST_DTYPE>>> seqlogProbs;
  std::vector<TensorView<CAST_DTYPE>> seq_alphas;
  std::vector<TensorView<CAST_DTYPE>> seq_betas;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;

  for (int b = 0; b < B; ++b) {
    seqlogProbs.push_back(TensorView<LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(
            const_cast<CAST_DTYPE*>(logProbs)) +
            b * maxT * maxU));
    seq_alphas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
    seq_betas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
  }

  std::vector<CAST_DTYPE> scores(B << 1);

  //#pragma omp parallel for
  for (int t = 0; t < (B << 1); ++t) { // use max 2 * B threads.
    int i = (t >> 1);
    scores[t] = ComputeAlphaOrBetaOneSequence<CAST_DTYPE>(
        /*thread=*/t,
        /*options=*/options,
        /*logProbs=*/seqlogProbs[i],
        /*srcLen=*/srcLengths[i],
        /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
        /*alpha=*/seq_alphas[i],
        /*beta=*/seq_betas[i]);
  }

  for (int b = 0; b < B; ++b) {
    // Even thread slots hold the backward (beta) score; the cost is the
    // negative log-likelihood.
    costs[b] = -scores[b << 1];
  }
}
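// Gradient of the loss with respect to the joint-network logits (a summary of
// the code that follows, not part of the original header). In the fused
// log-softmax case, with g = log_softmax(z)(t, u, d) + alpha(t, u) - beta(0, 0):
//
//   dLoss/dz(t, u, d) = exp(g + beta(t, u)) - exp(g + beta(t', u'))
//
// where (t', u') is (t + 1, u) if d is blank, (t, u + 1) if d is the target
// label at u, and the second term is dropped otherwise; the final blank at
// (T - 1, U - 1) uses beta(t', u') = 0.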
template <typename DTYPE, typename CAST_DTYPE>
void ComputeGradientsOneSequence(
    const Options& options,
    TensorView<const DTYPE>& logits,
    const int* targets,
    int srcLen,
    int tgtLen,
    TensorView<const CAST_DTYPE>& denom,
    TensorView<const CAST_DTYPE>& alpha,
    TensorView<const CAST_DTYPE>& beta,
    TensorView<DTYPE>& gradients) {
  // Don't zero out the gradients here, as gradients may reuse the memory that
  // backs logits.
  const int& T = srcLen;
  const int& U = tgtLen;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;
  const CAST_DTYPE clamp = options.clamp_;
  const bool& fusedLogSmax = options.fusedLogSmax_;

  CAST_DTYPE cost = -beta({0, 0});

  if (fusedLogSmax) {
    // Note: the gradient below differs from numpy_transducer, because the
    // log_softmax is computed more efficiently within the loss to save memory.
    // The details of the implementation / equations can be found in Sec. 3.2
    // (function merging) of:
    // https://www.microsoft.com/en-us/research/uploads/prod/2019/10/RNNT.pdf
    for (int t = 0; t < T; ++t) {
      for (int u = 0; u < U; ++u) {
        CAST_DTYPE c = alpha({t, u}) + cost - denom({t, u});
        for (int d = 0; d < D; ++d) {
          CAST_DTYPE g = CAST_DTYPE(logits({t, u, d})) + c;
          if (d == blank && t == T - 1 && u == U - 1) {
            // last blank transition.
            gradients({t, u, d}) = std::exp(g + beta({t, u})) - std::exp(g);
          } else if (d == blank && t < T - 1) {
            gradients({t, u, d}) =
                std::exp(g + beta({t, u})) - std::exp(g + beta({t + 1, u}));
          } else if (u < U - 1 && d == targets[u]) {
            gradients({t, u, d}) =
                std::exp(g + beta({t, u})) - std::exp(g + beta({t, u + 1}));
          } else {
            gradients({t, u, d}) = std::exp(g + beta({t, u}));
          }

          if (clamp > 0) {
            gradients({t, u, d}) =
                math::min(CAST_DTYPE(gradients({t, u, d})), clamp);
            gradients({t, u, d}) =
                math::max(CAST_DTYPE(gradients({t, u, d})), -clamp);
          }
        }
      }
    }
  } else {
    // Inputs are already log-probabilities; accumulate the path posteriors
    // directly.
    for (int t = 0; t < T; ++t) {
      for (int u = 0; u < U; ++u) {
        for (int d = 0; d < D; ++d) {
          CAST_DTYPE g = cost + CAST_DTYPE(logits({t, u, d}));
          if (d == blank && t == T - 1 && u == U - 1) {
            // last blank transition.
            gradients({t, u, d}) = g + alpha({t, u});
          } else if (d == blank && t < T - 1) {
            gradients({t, u, d}) = g + alpha({t, u}) + beta({t + 1, u});
          } else if (u < U - 1 && d == targets[u]) {
            gradients({t, u, d}) = g + alpha({t, u}) + beta({t, u + 1});
          } else {
            gradients({t, u, d}) = g + CAST_DTYPE(-INFINITY);
          }

          gradients({t, u, d}) = -(std::exp(gradients({t, u, d})));

          if (clamp > 0) {
            gradients({t, u, d}) =
                math::min(CAST_DTYPE(gradients({t, u, d})), clamp);
            gradients({t, u, d}) =
                math::max(CAST_DTYPE(gradients({t, u, d})), -clamp);
          }
        }
      }
    }
  }

  // Zero out the rest of the gradients; this is necessary when the gradients
  // buffer reuses the logits memory, so check the memory location to see
  // whether that is the case.
  if (&gradients({0, 0, 0}) == &logits({0, 0, 0})) {
    const int& maxT = options.maxSrcLen_;
    const int& maxU = options.maxTgtLen_;
    for (int t = T; t < maxT; ++t) {
      for (int u = 0; u < maxU; ++u) {
        for (int d = 0; d < D; ++d) {
          gradients({t, u, d}) = 0.;
        }
      }
    }
    for (int t = 0; t < T; ++t) {
      for (int u = U; u < maxU; ++u) {
        for (int d = 0; d < D; ++d) {
          gradients({t, u, d}) = 0.;
        }
      }
    }
  }
}
template <typename DTYPE, typename CAST_DTYPE>
void ComputeGradients(
    const Options& options,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients) {
  std::vector<TensorView<const DTYPE>> seqLogits;
  std::vector<const int*> seqTargets;
  std::vector<TensorView<const CAST_DTYPE>> seqDenoms;
  std::vector<TensorView<const CAST_DTYPE>> seq_alphas;
  std::vector<TensorView<const CAST_DTYPE>> seq_betas;
  std::vector<TensorView<DTYPE>> seq_gradients;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;
  const int& D = options.numTargets_;

  for (int b = 0; b < B; ++b) {
    seqLogits.push_back(TensorView<const DTYPE>(
        {maxT, maxU, D}, logits + b * maxT * maxU * D));
    seqTargets.push_back(targets + b * (maxU - 1));
    seqDenoms.push_back(TensorView<const CAST_DTYPE>(
        {maxT, maxU}, denominators + b * maxT * maxU));
    seq_alphas.push_back(TensorView<const CAST_DTYPE>(
        {maxT, maxU}, alphas + b * maxT * maxU));
    seq_betas.push_back(TensorView<const CAST_DTYPE>(
        {maxT, maxU}, betas + b * maxT * maxU));
    seq_gradients.push_back(TensorView<DTYPE>(
        {maxT, maxU, D}, gradients + b * maxT * maxU * D));
  }

  //#pragma omp parallel for
  for (int b = 0; b < B; ++b) { // one thread per sequence.
    ComputeGradientsOneSequence<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*logits=*/seqLogits[b],
        /*targets=*/seqTargets[b],
        /*srcLen=*/srcLengths[b],
        /*tgtLen=*/tgtLengths[b] + 1, // with prepended blank.
        /*denom=*/seqDenoms[b],
        /*alpha=*/seq_alphas[b],
        /*beta=*/seq_betas[b],
        /*gradients=*/seq_gradients[b]);
  }
}

// Standalone forward pass: fills alphas only.
template <typename DTYPE, typename CAST_DTYPE>
void ComputeAlphas(
    const Options& options,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    CAST_DTYPE* alphas) {
  std::vector<TensorView<LogProbs<CAST_DTYPE>>> seqlogProbs;
  std::vector<TensorView<CAST_DTYPE>> seq_alphas;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;

  for (int b = 0; b < B; ++b) {
    seqlogProbs.push_back(TensorView<LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(
            const_cast<CAST_DTYPE*>(logProbs)) +
            b * maxT * maxU));
    seq_alphas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
  }

  //#pragma omp parallel for
  for (int i = 0; i < B; ++i) { // one thread per sequence.
    ComputeAlphaOneSequence<CAST_DTYPE>(
        /*logProbs=*/seqlogProbs[i],
        /*srcLen=*/srcLengths[i],
        /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
        /*alpha=*/seq_alphas[i]);
  }
}
// Standalone backward pass: fills betas only.
template <typename DTYPE, typename CAST_DTYPE>
void ComputeBetas(
    const Options& options,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    CAST_DTYPE* costs,
    CAST_DTYPE* betas) {
  std::vector<TensorView<LogProbs<CAST_DTYPE>>> seqlogProbs;
  std::vector<TensorView<CAST_DTYPE>> seq_betas;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;

  for (int b = 0; b < B; ++b) {
    seqlogProbs.push_back(TensorView<LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(
            const_cast<CAST_DTYPE*>(logProbs)) +
            b * maxT * maxU));
    seq_betas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
  }

  //#pragma omp parallel for
  for (int i = 0; i < B; ++i) { // one thread per sequence.
    ComputeBetaOneSequence<CAST_DTYPE>(
        /*logProbs=*/seqlogProbs[i],
        /*srcLen=*/srcLengths[i],
        /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
        /*betas=*/seq_betas[i]);
  }
}

} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
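// Typical call sequence (a sketch of how a caller is expected to drive these
// kernels; the ordering is an assumption, not part of this header):
//   1. LogSumExp2D over the (B * maxT * maxU, D) logits -> denominators.
//   2. ComputeLogProbs -> per-cell blank/target log-probabilities.
//   3. ComputeAlphasBetas -> alphas, betas, and per-sequence costs.
//   4. ComputeGradients -> gradients w.r.t. the logits (may reuse their memory).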