Commit d946a7c8 authored by flyingdown's avatar flyingdown
Browse files

rnnt

parent b90d7988
......@@ -51,6 +51,16 @@ if(BUILD_RNNT)
rnnt/gpu/compute.cu
)
endif()
if (USE_ROCM)
list(
APPEND
LIBTORCHAUDIO_SOURCES
rnnt/dcu/compute_alphas.cpp
rnnt/dcu/compute_betas.cpp
rnnt/dcu/compute.cpp
)
endif()
endif()
if(USE_CUDA)
......@@ -72,6 +82,25 @@ if(USE_CUDA)
)
endif()
if(USE_ROCM)
list(
APPEND
LIBTORCHAUDIO_INCLUDE_DIRS
${CUDA_TOOLKIT_INCLUDE}
)
list(
APPEND
LIBTORCHAUDIO_LINK_LIBRARIES
${C10_CUDA_LIBRARY}
${CUDA_CUDART_LIBRARY}
)
list(
APPEND
LIBTORCHAUDIO_COMPILE_DEFINITIONS
USE_ROCM
)
endif()
if(BUILD_KALDI)
list(APPEND LIBTORCHAUDIO_LINK_LIBRARIES kaldi)
list(APPEND LIBTORCHAUDIO_SOURCES kaldi.cpp)
......
#include <c10/hip/HIPStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/dcu/gpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace gpu {
// Entry point into RNNT Loss
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
TORCH_CHECK(
logits.device().type() == logit_lengths.device().type(),
"logits and logit_lengths must be on the same device");
TORCH_CHECK(
logits.device().type() == target_lengths.device().type(),
"logits and target_lengths must be on the same device");
TORCH_CHECK(
logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
"logits must be float32 or float16 (half) type");
TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
TORCH_CHECK(
logit_lengths.dtype() == torch::kInt32,
"logit_lengths must be int32 type");
TORCH_CHECK(
target_lengths.dtype() == torch::kInt32,
"target_lengths must be int32 type");
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(
logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
TORCH_CHECK(
target_lengths.is_contiguous(), "target_lengths must be contiguous");
TORCH_CHECK(
logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
TORCH_CHECK(
targets.dim() == 2, "targets must be 2-D (batch, max target length)");
TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
TORCH_CHECK(
logit_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and logit_lengths");
TORCH_CHECK(
target_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and target_lengths");
TORCH_CHECK(
targets.size(0) == logits.size(0),
"batch dimension mismatch between logits and targets");
TORCH_CHECK(
blank >= 0 && blank < logits.size(-1),
"blank must be within [0, logits.shape[-1])");
TORCH_CHECK(
logits.size(1) == at::max(logit_lengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
logits.size(2) == at::max(target_lengths).item().toInt() + 1,
"output length mismatch");
TORCH_CHECK(
targets.size(1) == at::max(target_lengths).item().toInt(),
"target length mismatch");
Options options;
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::hip::getCurrentHIPStream();
hipSetDevice(logits.get_device());
options.device_ = GPU;
torch::Tensor costs = torch::empty(
options.batchSize_ * options.nHypos_,
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
c10::optional<torch::Tensor> gradients = torch::zeros_like(logits);
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
switch (logits.scalar_type()) {
case torch::ScalarType::Float: {
Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*gradients=*/gradients->data_ptr<float>());
break;
}
case torch::ScalarType::Half: {
Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<c10::Half>(),
/*targets=*/targets.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<c10::Half>(),
/*gradients=*/gradients->data_ptr<c10::Half>());
break;
}
default: {
break;
}
};
return std::make_tuple(costs, gradients);
}
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("rnnt_loss", &compute);
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
#include <c10/hip/HIPStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/dcu/gpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace gpu {
torch::Tensor compute_alphas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::hip::getCurrentHIPStream();
hipSetDevice(logits.get_device());
options.device_ = GPU;
torch::Tensor alphas = torch::zeros(
{options.batchSize_ * options.nHypos_,
options.maxSrcLen_,
options.maxTgtLen_},
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
// Only support float, this is mainly to enable easy
// unit-testing
ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*alphas=*/alphas.data_ptr<float>());
return alphas;
}
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("rnnt_loss_alphas", &compute_alphas);
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
#include <c10/hip/HIPStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/dcu/gpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace gpu {
torch::Tensor compute_betas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::hip::getCurrentHIPStream();
hipSetDevice(logits.get_device());
options.device_ = GPU;
torch::Tensor costs = torch::empty(
target_lengths.size(0),
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor betas = torch::zeros(
{options.batchSize_ * options.nHypos_,
options.maxSrcLen_,
options.maxTgtLen_},
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
// Only support float, this is mainly to enable easy
// unit-testing
ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*betas=*/betas.data_ptr<float>());
return betas;
}
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("rnnt_loss_betas", &compute_betas);
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
#pragma once
#ifdef USE_ROCM
#include <torchaudio/csrc/rnnt/dcu/math.cuh>
namespace torchaudio {
namespace rnnt {
template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceMax2D(
int dim,
const DTYPE* inputs, // [N, dim]
CAST_DTYPE* outputs) {
__shared__ CAST_DTYPE shared[NUM_THREADS];
// each thread reduces one matrix row
int offset = blockIdx.x * dim; // [n, 0]
CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0)
for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
CAST_DTYPE next = inputs[offset + d];
if (next > val) {
val = next;
}
}
shared[threadIdx.x] = val;
__syncthreads();
for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
if (shared[threadIdx.x + stride] > shared[threadIdx.x]) {
shared[threadIdx.x] = shared[threadIdx.x + stride];
val = shared[threadIdx.x];
}
}
__syncthreads();
}
CAST_DTYPE shf;
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
shf = __shfl_down(val, stride);
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
if (shf > val) {
val = shf;
}
}
}
if (threadIdx.x == 0) {
outputs[blockIdx.x] = val;
}
}
template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceLogSumExpGivenMax2D(
int dim,
const DTYPE* inputs, // [N, dim]
CAST_DTYPE* outputs) { // in: max -> out: logsum
__shared__ CAST_DTYPE shared[NUM_THREADS];
CAST_DTYPE max = outputs[blockIdx.x];
CAST_DTYPE val = 0;
int offset = blockIdx.x * dim;
for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
val = val + std::exp(CAST_DTYPE(inputs[offset + d]) - max);
}
shared[threadIdx.x] = val;
__syncthreads();
for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
val = shared[threadIdx.x] + shared[threadIdx.x + stride];
shared[threadIdx.x] = val;
}
__syncthreads();
}
CAST_DTYPE shf;
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
shf = __shfl_down(val, stride);
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
val = val + shf;
}
}
if (threadIdx.x == 0) {
outputs[blockIdx.x] = max + std::log(val);
}
}
} // namespace rnnt
} // namespace torchaudio
#endif // USE_ROCM
#pragma once
#ifdef USE_ROCM
#include <cassert>
#include <torchaudio/csrc/rnnt/dcu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/dcu/kernels.h>
#include <torchaudio/csrc/rnnt/dcu/math.cuh>
namespace torchaudio {
namespace rnnt {
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeLogProbs(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
const CAST_DTYPE* denominators,
CAST_DTYPE* logProbs,
int H = 1) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int& D = numTargets;
const int bTgt = blockIdx.z; // 0 <= b < B
const int bSrc = bTgt / H;
const int T = srcLengths[bSrc];
const int U = tgtLengths[bTgt] + 1;
const int t = blockIdx.x * blockDim.x + threadIdx.x;
const int u = blockIdx.y;
if (t >= T || u >= U) { // out of boundary.
return;
}
Indexer3D indexer(maxT, maxU);
int idx = indexer(bTgt, t, u);
// skip: log_prob(b, t, u).skip() = logits(b, t, u, blank) - denom(b, t, u).
logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];
if (u < U - 1) {
// emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
// u).
int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
}
}
template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeAlphas(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* alpha_counters,
volatile CAST_DTYPE* alphas,
int H = 1) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int bTgt = blockIdx.z; // 0 <= b < B
const int bSrc = bTgt / H;
const int T = srcLengths[bSrc];
const int U = tgtLengths[bTgt] + 1;
const int t = blockIdx.x * blockDim.x + threadIdx.x + 1;
const int u = blockIdx.y + 1;
if (t >= T || u >= U) { // out of boundary.
return;
}
int* counter = alpha_counters + Indexer2D(maxU)(bTgt, blockIdx.y);
Indexer3D idxr(maxT, maxU);
if (t == 1 && u == 1) {
alphas[idxr(bTgt, 0, 0)] = 0;
}
if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready.
while (atomicAdd(counter, 0) < blockIdx.x) {
}
}
if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready.
while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
}
}
if (t == 1 && u < U) {
// alpha(0, u) = alpha(0, u - 1) + logProbs(0, u - 1).emit().
alphas[idxr(bTgt, 0, u)] = alphas[idxr(bTgt, 0, u - 1)] +
logProbs[(idxr(bTgt, 0, u - 1) << 1) + LOG_PROBS_EMIT_IDX];
}
if (blockIdx.y == 0 && t < T) {
CAST_DTYPE skip_prob =
logProbs[(idxr(bTgt, t - 1, 0) << 1) + LOG_PROBS_SKIP_IDX];
CAST_DTYPE val;
#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
val = __shfl_up(skip_prob, i);
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
}
val = alphas[idxr(bTgt, blockIdx.x * blockDim.x, 0)];
alphas[idxr(bTgt, t, 0)] = skip_prob + val;
}
if (t < T && u < U) {
CAST_DTYPE skip_prob =
logProbs[(idxr(bTgt, t - 1, u) << 1) + LOG_PROBS_SKIP_IDX];
CAST_DTYPE emit_prob =
logProbs[(idxr(bTgt, t, u - 1) << 1) + LOG_PROBS_EMIT_IDX];
CAST_DTYPE skip =
alphas[idxr(bTgt, blockIdx.x * blockDim.x, u)] + skip_prob;
CAST_DTYPE emit = alphas[idxr(bTgt, t, u - 1)] + emit_prob;
CAST_DTYPE val = math::lse(skip, emit);
CAST_DTYPE out = val;
for (int i = 1; i < warpSize; ++i) {
val = __shfl_up(val, 1);
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
}
}
alphas[idxr(bTgt, t, u)] = out;
}
if (threadIdx.x == 0) {
__threadfence();
atomicAdd(counter, 1);
}
}
template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeBetasCosts(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* betaCounters,
volatile CAST_DTYPE* betas,
DTYPE* costs,
int H = 1) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int bTgt = blockIdx.z; // 0 <= b < B
const int bSrc = bTgt / H;
const int T = srcLengths[bSrc];
const int U = tgtLengths[bTgt] + 1;
const int t = T - 2 - blockIdx.x * blockDim.x - threadIdx.x;
const int u = U - 2 - blockIdx.y;
if (t < 0 || u < 0) { // out of boundary.
return;
}
int* counter = betaCounters + Indexer2D(maxU)(bTgt, blockIdx.y);
Indexer3D idxr(maxT, maxU);
if (t == T - 2 && u == U - 2) {
betas[idxr(bTgt, T - 1, U - 1)] =
logProbs[(idxr(bTgt, T - 1, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
}
if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready.
while (atomicAdd(counter, 0) < blockIdx.x) {
}
}
if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready.
while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
}
}
if (t == T - 2 && u >= 0) {
betas[idxr(bTgt, T - 1, u)] = betas[idxr(bTgt, T - 1, u + 1)] +
logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX];
}
if (blockIdx.y == 0 && t >= 0) {
CAST_DTYPE skip_prob =
logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
CAST_DTYPE val;
#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
val = __shfl_up(skip_prob, i);
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
}
betas[idxr(bTgt, t, U - 1)] =
betas[idxr(bTgt, T - 1 - blockIdx.x * blockDim.x, U - 1)] + skip_prob;
}
if (t >= 0 && u >= 0) {
CAST_DTYPE skip_prob =
logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX];
CAST_DTYPE emit_prob =
logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX];
CAST_DTYPE skip = betas[idxr(bTgt, t + threadIdx.x + 1, u)] + skip_prob;
CAST_DTYPE emit = betas[idxr(bTgt, t, u + 1)] + emit_prob;
CAST_DTYPE val = math::lse(skip, emit);
CAST_DTYPE out = val;
for (int i = 1; i < warpSize; ++i) {
val = __shfl_up(val, 1);
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
}
}
betas[idxr(bTgt, t, u)] = out;
if (t == 0 && u == 0) { // use -beta(0, 0) as cost.
costs[bTgt] = DTYPE(-out);
}
}
if (threadIdx.x == 0) {
__threadfence();
atomicAdd(counter, 1);
}
}
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasBetasCosts(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* alpha_counters,
volatile CAST_DTYPE* alphas,
int* betaCounters,
volatile CAST_DTYPE* betas,
DTYPE* costs,
int warpSize = 0,
int numWarps = 0,
int H = 1) {
assert(threadIdx.y == 0 || threadIdx.y == 1);
if (threadIdx.y == 0) {
ComputeAlphas<DTYPE, CAST_DTYPE>(
/*maxSrcLen=*/maxSrcLen,
/*maxTgtLen=*/maxTgtLen,
/*numTargets=*/numTargets,
/*blank=*/blank,
/*logProbs=*/logProbs,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/alpha_counters,
/*alphas=*/alphas,
H);
} else { // threadIdx.y == 1
ComputeBetasCosts<DTYPE, CAST_DTYPE>(
/*maxSrcLen=*/maxSrcLen,
/*maxTgtLen=*/maxTgtLen,
/*numTargets=*/numTargets,
/*blank=*/blank,
/*logProbs=*/logProbs,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*betaCounters=*/betaCounters,
/*beta=*/betas,
/*costs=*/costs,
H);
}
}
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeGradients(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
CAST_DTYPE clamp,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
const CAST_DTYPE* denominators,
const CAST_DTYPE* alphas,
const CAST_DTYPE* betas,
DTYPE* gradients,
int H = 1) {
const int bTgt = blockIdx.z; // 0 <= b < B
const int t = blockIdx.x * blockDim.x + threadIdx.x;
const int u = blockIdx.y;
ComputeGradientsElement(
bTgt,
t,
u,
maxSrcLen,
maxTgtLen,
numTargets,
blank,
clamp,
logits,
targets,
srcLengths,
tgtLengths,
denominators,
alphas,
betas,
gradients,
H);
}
// This is a __global__ wrapper around ComputeAlphas
// device kernel to enable unit testing
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasWrapper(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* alpha_counters,
volatile CAST_DTYPE* alphas,
int H = 1) {
ComputeAlphas<DTYPE, CAST_DTYPE>(
maxSrcLen,
maxTgtLen,
numTargets,
blank,
logProbs,
srcLengths,
tgtLengths,
alpha_counters,
alphas,
H);
}
// This is a __global__ wrapper around ComputeBetas
// device kernel to enable unit testing
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeBetasWrapper(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* betaCounters,
volatile CAST_DTYPE* betas,
DTYPE* costs,
int H = 1) {
ComputeBetasCosts<DTYPE, CAST_DTYPE>(
maxSrcLen,
maxTgtLen,
numTargets,
blank,
logProbs,
srcLengths,
tgtLengths,
betaCounters,
betas,
costs,
H);
}
// #undef LOG_PROBS_SKIP_IDX
// #undef LOG_PROBS_EMIT_IDX
} // namespace rnnt
} // namespace torchaudio
#endif // USE_ROCM
#pragma once
#ifdef USE_ROCM
#include <torchaudio/csrc/rnnt/workspace.h>
#include <torchaudio/csrc/rnnt/dcu/gpu_kernel_utils.cuh>
#include <torchaudio/csrc/rnnt/dcu/gpu_kernels.cuh>
namespace torchaudio {
namespace rnnt {
namespace gpu {
#define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(
hipError_t code,
const char* file,
int line,
bool abort = true) {
if (code != hipSuccess) {
fprintf(
stderr,
"\nGPUassert: %s %s %d\n",
hipGetErrorString(code),
file,
line);
if (abort)
exit(code);
}
}
template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(
hipStream_t stream,
int N,
int D,
const DTYPE* logits, // [N, D]
CAST_DTYPE* outputs) {
{ // compute max among D.
dim3 block_dims(N);
dim3 thread_dims(REDUCE_THREADS);
ReduceMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*dim=*/D,
/*inputs=*/logits,
/*outputs=*/outputs);
// BUGBUG: These error codes are only accurate when launching with
// blocking. Otherwise they usually reflect earlier errors.
if (hipGetLastError() != hipSuccess) {
return COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED;
}
}
{ // compute log(sum(exp(d_i - max)))
dim3 block_dims(N);
dim3 thread_dims(REDUCE_THREADS);
ReduceLogSumExpGivenMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*dim=*/D,
/*inputs=*/logits,
/*outputs=*/outputs);
if (hipGetLastError() != hipSuccess) {
return COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED;
}
}
return SUCCESS;
}
// Inputs:
// workspace: workspace.
// logits: pointer to (B, max_T, max_U, D) logits.
// targets: pointer to (B, max_U - 1) targets in the batch.
// srcLengths: pointer to (B, ) source lengths in the batch.
// tgtLengths: pointer to (B, ) target lengths in the batch.
//
// Outputs:
// costs: pointer to (B, ) costs in the batch.
// gradients: pointer to (B, max_T, max_U, D) gradients in the batch.
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* gradients = nullptr) {
const Options& options = workspace.GetOptions();
const hipStream_t& stream = options.stream_;
const int& B = options.batchSize_;
const int& H = options.nHypos_;
const int& max_T = options.maxSrcLen_;
const int& max_U = options.maxTgtLen_;
const int& D = options.numTargets_;
const int& blank = options.blank_;
const CAST_DTYPE clamp = options.clamp_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
/*N=*/B * H * max_T * max_U,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
if (status != SUCCESS) {
return status;
}
}
{ // compute log probability pairs (blank and target).
int num_segments =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_segments, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H);
if (hipGetLastError() != hipSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
}
}
{ // compute alphas, betas and costs.
// warp is usually a group of threads (32)
int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
// each block is identified by 3 d tuple.
// we are using num_warp * max_U * B * H blocks
// where num_warp is division among Time axis
dim3 block_dims(num_warps, max_U, B * H);
// each thread is identified by a 2 d tuple
// 2nd dim is 2. 1 for alpha, 1 for beta
dim3 thread_dims(WARP_SIZE, 2);
ComputeAlphasBetasCosts<DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
/*alphas=*/workspace.GetPointerToAlphas(),
/*beta_counters=*/workspace.GetPointerToBetaCounters(),
/*betas=*/workspace.GetPointerToBetas(),
/*costs=*/costs,
/*warp_size=*/WARP_SIZE,
/*num_warps=*/num_warps,
H);
if (hipGetLastError() != hipSuccess) {
return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
}
}
if (gradients != nullptr) { // compute gradients.
// don't set gradients to zero to here as gradients might reuse memory from
// logits
int num_blocks =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_blocks, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeGradients<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*clamp=*/clamp,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*alphas=*/workspace.GetPointerToAlphas(),
/*betas=*/workspace.GetPointerToBetas(),
/*gradients=*/gradients,
H);
if (hipGetLastError() != hipSuccess) {
return COMPUTE_GRADIENTS_FAILED;
}
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* alphas) {
const Options& options = workspace.GetOptions();
const hipStream_t& stream = options.stream_;
const int& B = options.batchSize_;
const int& H = options.nHypos_;
const int& max_T = options.maxSrcLen_;
const int& max_U = options.maxTgtLen_;
const int& D = options.numTargets_;
const int& blank = options.blank_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
/*N=*/B * H * max_T * max_U,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
if (status != SUCCESS) {
return status;
}
}
{ // compute log probability pairs (blank and target).
int num_segments =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_segments, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H);
if (hipGetLastError() != hipSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
}
}
{ // compute alphas
// warp is usually a group of threads (32)
int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
// each block is identified by 3 d tuple.
// we are using num_warp * max_U * B blocks
// where num_warp is division among Time axis
dim3 block_dims(num_warps, max_U, B * H);
// each thread is identified by a 2 d tuple
// 2nd dim is 1 for alpha only
dim3 thread_dims(WARP_SIZE, 1);
ComputeAlphasWrapper<DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
/*alphas=*/(volatile DTYPE*)alphas,
H);
if (hipGetLastError() != hipSuccess) {
return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
}
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* betas) {
const Options& options = workspace.GetOptions();
const hipStream_t& stream = options.stream_;
const int& B = options.batchSize_;
const int& H = options.nHypos_;
const int& max_T = options.maxSrcLen_;
const int& max_U = options.maxTgtLen_;
const int& D = options.numTargets_;
const int& blank = options.blank_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
/*N=*/B * H * max_T * max_U,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
if (status != SUCCESS) {
return status;
}
}
{ // compute log probability pairs (blank and target).
int num_segments =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_segments, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H);
if (hipGetLastError() != hipSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
}
}
{ // compute betas
// warp is usually a group of threads (32)
int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
// each block is identified by 3 d tuple.
// we are using num_warp * max_U * B blocks
// where num_warp is division among Time axis
dim3 block_dims(num_warps, max_U, B * H);
// each thread is identified by a 2 d tuple
// 2nd dim is 1 for betas only
dim3 thread_dims(WARP_SIZE, 1);
ComputeBetasWrapper<DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/workspace.GetPointerToBetaCounters(),
/*alphas=*/(volatile DTYPE*)betas,
costs,
H);
if (hipGetLastError() != hipSuccess) {
return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
}
}
return SUCCESS;
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
#endif // USE_ROCM
#pragma once
#ifdef USE_C10_HALF
#include "c10/util/Half.h"
#endif // USE_C10_HALF
#include <torchaudio/csrc/rnnt/macros.h>
namespace torchaudio {
namespace rnnt {
struct alignas(sizeof(__half)) Half {
__half x;
HOST_AND_DEVICE Half() = default;
FORCE_INLINE HOST_AND_DEVICE Half(float f) {
x = __float2half_rn(f);
if (isinf(__half2float(x))) {
x = __float2half_rz(f); // round toward 0.
}
}
FORCE_INLINE HOST_AND_DEVICE operator float() const {
return __half2float(x);
}
FORCE_INLINE HOST_AND_DEVICE Half(__half f) {
x = f;
}
FORCE_INLINE HOST_AND_DEVICE operator __half() const {
return x;
}
};
} // namespace rnnt
} // namespace torchaudio
#pragma once
#include <cassert>
#include <torchaudio/csrc/rnnt/dcu/math.cuh>
namespace torchaudio {
namespace rnnt {
inline HOST_AND_DEVICE bool in_range(
int start,
int end, // inclusive
int val) {
return start <= val && val <= end;
}
#define LOG_PROBS_SKIP_IDX 0
#define LOG_PROBS_EMIT_IDX 1
struct Indexer2D {
const int& size2_;
FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {}
FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) {
return index1 * size2_ + index2;
}
};
struct Indexer3D {
const int& size2_;
const int& size3_;
FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3)
: size2_(size2), size3_(size3) {}
FORCE_INLINE HOST_AND_DEVICE int operator()(
int index1,
int index2,
int index3) {
return (index1 * size2_ + index2) * size3_ + index3;
}
};
struct Indexer4D {
const int& size2_;
const int& size3_;
const int& size4_;
HOST_AND_DEVICE Indexer4D(
const int& size2,
const int& size3,
const int& size4)
: size2_(size2), size3_(size3), size4_(size4) {}
HOST_AND_DEVICE int operator()(
int index1,
int index2,
int index3,
int index4) {
return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4;
}
};
} // namespace rnnt
} // namespace torchaudio
#pragma once
#include <cassert>
#include <torchaudio/csrc/rnnt/dcu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/dcu/math.cuh>
namespace torchaudio {
namespace rnnt {
template <typename DTYPE, typename CAST_DTYPE>
HOST_AND_DEVICE void ComputeGradientsElement(
int bTgt,
int t,
int u,
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
CAST_DTYPE clamp,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
const CAST_DTYPE* denominators,
const CAST_DTYPE* alphas,
const CAST_DTYPE* betas,
DTYPE* gradients,
int H = 1) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int& D = numTargets;
const int bSrc = bTgt / H;
const int T = srcLengths[bSrc];
const int U = tgtLengths[bTgt] + 1;
if (t >= T || u >= U) { // out of boundary.
if (gradients == logits && t < maxT && u < maxU) {
// gradients and logits are pointing to the same memory location
Indexer3D idxr3(maxT, maxU);
int idx_b_t_u_zero = idxr3(bTgt, t, u);
if (idx_b_t_u_zero != -1) {
int start = idx_b_t_u_zero * D;
for (int b_t_u_d = start; b_t_u_d < start + D; ++b_t_u_d) {
gradients[b_t_u_d] = 0;
}
}
}
return;
}
int costIdx = bTgt * maxT * maxU;
CAST_DTYPE cost = -(betas[costIdx]);
Indexer2D idxr2(maxU - 1);
int idx_b_t_u, idx_b_t_up1, idx_b_tp1_u;
Indexer3D idxr3(maxT, maxU);
idx_b_t_u = idxr3(bTgt, t, u);
idx_b_t_up1 = idxr3(bTgt, t, u + 1);
idx_b_tp1_u = idxr3(bTgt, t + 1, u);
if (idx_b_t_u == -1) {
return;
}
if (isinf(cost) || isnan(cost)) {
for (int d = 0; d < D; ++d) {
int b_t_u_d = idx_b_t_u * D + d;
gradients[b_t_u_d] = 0;
}
return;
}
CAST_DTYPE c = alphas[idx_b_t_u] + cost - denominators[idx_b_t_u];
for (int d = 0; d < D; ++d) {
int b_t_u_d = idx_b_t_u * D + d;
CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;
if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
} else if (t < T - 1 && d == blank) {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
if (idx_b_tp1_u != -1) {
gradients[b_t_u_d] =
gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
}
} else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
if (idx_b_t_up1 != -1) {
gradients[b_t_u_d] =
gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
}
} else {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
}
if (clamp > 0) {
auto g = CAST_DTYPE(gradients[b_t_u_d]);
gradients[b_t_u_d] = math::min(g, clamp);
gradients[b_t_u_d] = math::max(g, -clamp);
}
}
}
} // namespace rnnt
} // namespace torchaudio
#pragma once
#ifdef USE_ROCM
#include <cmath>
#endif // USE_ROCM
#include <torchaudio/csrc/rnnt/dcu/half.cuh>
namespace torchaudio {
namespace rnnt {
namespace math {
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) {
if (x > y)
return x;
else
return y;
}
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) {
if (x > y)
return y;
else
return x;
}
// log_sum_exp
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y);
template <>
FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) {
if (y > x) {
return y + log1pf(expf(x - y));
} else {
return x + log1pf(expf(y - x));
}
}
} // namespace math
} // namespace rnnt
} // namespace torchaudio
#pragma once
#ifdef USE_CUDA
#if defined(USE_CUDA)
#define WARP_SIZE 32
#define MAX_THREADS_PER_BLOCK 1024
#define REDUCE_THREADS 256
......@@ -8,6 +8,14 @@
#define FORCE_INLINE __forceinline__
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#elif defined(USE_ROCM)
#define WARP_SIZE 64
#define MAX_THREADS_PER_BLOCK 1024
#define REDUCE_THREADS 256
#define HOST_AND_DEVICE __host__ __device__
#define FORCE_INLINE __forceinline__
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#else
#define HOST_AND_DEVICE
#define FORCE_INLINE inline
......
......@@ -6,6 +6,10 @@
#include <cuda_runtime.h>
#endif // USE_CUDA
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif // USE_ROCM
#include <torchaudio/csrc/rnnt/macros.h>
#include <torchaudio/csrc/rnnt/types.h>
......@@ -18,6 +22,10 @@ typedef struct Options {
#ifdef USE_CUDA
// the stream to launch kernels in when using GPU.
cudaStream_t stream_;
#endif
#ifdef USE_ROCM
// the stream to launch kernels in when using GPU.
hipStream_t stream_;
#endif
// The maximum number of threads that can be used.
int numThreads_;
......
......@@ -27,7 +27,7 @@ class DtypeWorkspace {
~DtypeWorkspace() {}
static int ComputeSizeFromOptions(const Options& options) {
TORCH_CHECK_NE(options.device_, UNDEFINED);
CHECK_NE(options.device_, UNDEFINED);
return ComputeSizeForDenominators(options) +
ComputeSizeForLogProbs(options) + ComputeSizeForAlphas(options) +
ComputeSizeForBetas(options);
......@@ -36,7 +36,7 @@ class DtypeWorkspace {
void Free();
void Reset(const Options& options, DTYPE* data, int size) {
int needed_size = ComputeSizeFromOptions(options);
TORCH_CHECK_LE(needed_size, size);
CHECK_LE(needed_size, size);
options_ = options;
data_ = data;
size_ = size;
......@@ -98,7 +98,7 @@ class IntWorkspace {
void Reset(const Options& options, int* data, int size) {
int needed_size = ComputeSizeFromOptions(options);
TORCH_CHECK_LE(needed_size, size);
CHECK_LE(needed_size, size);
options_ = options;
data_ = data;
size_ = size;
......@@ -109,11 +109,11 @@ class IntWorkspace {
}
int* GetPointerToAlphaCounters() const {
TORCH_CHECK_EQ(options_.device_, GPU);
CHECK_EQ(options_.device_, GPU);
return data_;
}
int* GetPointerToBetaCounters() const {
TORCH_CHECK_EQ(options_.device_, GPU);
CHECK_EQ(options_.device_, GPU);
return GetPointerToAlphaCounters() + ComputeSizeForAlphaCounters(options_);
}
......@@ -131,10 +131,22 @@ class IntWorkspace {
ComputeSizeForBetaCounters(options_) * sizeof(int));
}
#endif // USE_CUDA
#ifdef USE_ROCM
if (data_ != nullptr && options_.device_ == GPU) {
hipMemset(
GetPointerToAlphaCounters(),
0,
ComputeSizeForAlphaCounters(options_) * sizeof(int));
hipMemset(
GetPointerToBetaCounters(),
0,
ComputeSizeForBetaCounters(options_) * sizeof(int));
}
#endif // USE_ROCM
}
static int ComputeSizeForAlphaCounters(const Options& options) { // B * U
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
if (options.device_ == GPU) {
return options.BU();
} else {
......@@ -145,7 +157,7 @@ class IntWorkspace {
#endif // USE_CUDA
}
static int ComputeSizeForBetaCounters(const Options& options) { // B * U
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
if (options.device_ == GPU) {
return options.BU();
} else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment