Unverified Commit d74d0604 authored by Caroline Chen, committed by GitHub

Remove fused_log_softmax option from RNNT Loss (#1615)

parent 9078c0b9
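
This commit removes the `fused_log_softmax` flag everywhere: the loss now always applies log_softmax internally ("fused"). As orientation before the diff, here is a hedged before/after sketch of the Python functional. The import path is an assumption (the loss lived in a prototype namespace at the time of this commit and graduated to `torchaudio.functional` in later releases); shapes are illustrative.

```python
import torch
from torchaudio.functional import rnnt_loss  # assumed import path, see note above

logits = torch.rand(1, 10, 3, 4, requires_grad=True)  # (batch, max_T, max_U + 1, num_classes)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([10], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

# Before: rnnt_loss(logits, targets, logit_lengths, target_lengths,
#                   blank=-1, clamp=-1, fused_log_softmax=True, reduction="mean")
# After (this commit): the flag is gone; raw logits go in, log_softmax is always fused.
loss = rnnt_loss(logits, targets, logit_lengths, target_lengths,
                 blank=-1, clamp=-1, reduction="mean")
loss.backward()
```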
@@ -71,7 +71,6 @@ class Autograd(TestBaseMixin):
             data["target_lengths"],  # target_lengths
             data["blank"],  # blank
             -1,  # clamp
-            True,  # fused_log_softmax
         )
         self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)
...
@@ -5,7 +5,6 @@ from .utils import (
     compute_with_numpy_transducer,
     compute_with_pytorch_transducer,
     get_basic_data,
-    get_B1_T10_U3_D4_data,
     get_B1_T2_U3_D5_data,
     get_B2_T4_U3_D3_data,
     get_random_data,
@@ -80,18 +79,3 @@ class RNNTLossTest:
         self._test_costs_and_gradients(
             data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
         )
-
-    def test_rnnt_nonfused_log_softmax(self):
-        for random in [False, True]:
-            data = get_B1_T10_U3_D4_data(
-                random=random,
-                dtype=torch.float32,
-                device=self.device,
-            )
-            data["fused_log_softmax"] = False
-            ref_costs, ref_gradients = compute_with_numpy_transducer(
-                data=data
-            )
-            self._test_costs_and_gradients(
-                data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
-            )
@@ -26,7 +26,6 @@ def compute_with_numpy_transducer(data):
 def compute_with_pytorch_transducer(data):
     costs = RNNTLoss(
         blank=data["blank"],
-        fused_log_softmax=data.get("fused_log_softmax", True),
         reduction="none",
     )(
         logits=data["logits"],
...
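
For reference, a hedged sketch of the `data` dict shape this helper consumes after the change. Key names come from this diff; the values, dtypes, and the exact set of extra keys are illustrative assumptions.

```python
import torch

# Hypothetical minimal data dict; real dicts come from get_basic_data() etc.
data = {
    "logits": torch.rand(1, 2, 3, 5, requires_grad=True),  # (B, T, U + 1, D)
    "targets": torch.tensor([[1, 2]], dtype=torch.int32),
    "logit_lengths": torch.tensor([2], dtype=torch.int32),
    "target_lengths": torch.tensor([2], dtype=torch.int32),
    "blank": -1,
}
# A leftover data["fused_log_softmax"] key is simply no longer read by the helper.
```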
@@ -13,17 +13,10 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
       const torch::Tensor& logit_lengths,
       const torch::Tensor& target_lengths,
       int64_t blank,
-      double clamp,
-      bool fused_log_softmax = true) {
+      double clamp) {
     torch::Tensor undef;
-    auto result = rnnt_loss(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank,
-        clamp,
-        fused_log_softmax);
+    auto result =
+        rnnt_loss(logits, targets, logit_lengths, target_lengths, blank, clamp);
     auto costs = std::get<0>(result);
     auto grads = std::get<1>(result).value_or(undef);
     ctx->save_for_backward({grads});
@@ -48,17 +41,10 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   at::AutoDispatchBelowADInplaceOrView guard;
   auto results = RNNTLossFunction::apply(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
+      logits, targets, logit_lengths, target_lengths, blank, clamp);
   return std::make_tuple(results[0], results[1]);
 }
...
@@ -7,19 +7,11 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                        .typed<decltype(rnnt_loss)>();
-  return op.call(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
+  return op.call(logits, targets, logit_lengths, target_lengths, blank, clamp);
 }

 TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
@@ -29,6 +21,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
       "Tensor logit_lengths,"
       "Tensor target_lengths,"
       "int blank,"
-      "float clamp,"
-      "bool fused_log_softmax=True) -> (Tensor, Tensor?)");
+      "float clamp) -> (Tensor, Tensor?)");
 }
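
With the trimmed schema, the registered op takes exactly six arguments. A hedged sketch of invoking it through the dispatcher (assumes a torchaudio build that registers `torchaudio::rnnt_loss`; note `blank` must be a valid class index here, because the negative-index convenience lives in the Python wrapper):

```python
import torch
import torchaudio  # importing torchaudio loads the extension that registers the op

logits = torch.rand(1, 2, 3, 5)  # (batch, max_T, max_U + 1, num_classes); illustrative
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

# Schema after this commit:
#   rnnt_loss(Tensor logits, Tensor targets, Tensor logit_lengths,
#             Tensor target_lengths, int blank, float clamp) -> (Tensor, Tensor?)
costs, gradients = torch.ops.torchaudio.rnnt_loss(
    logits, targets, logit_lengths, target_lengths, 4, -1.0)
```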
@@ -8,5 +8,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax);
+    double clamp);
@@ -12,8 +12,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -81,7 +80,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_softmax;

   CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
   options.device_ = CPU;
...
@@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -82,7 +81,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_softmax;

   CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
   options.stream_ = at::cuda::getCurrentCUDAStream();
...
@@ -23,8 +23,7 @@ __global__ void ComputeLogProbs(
     const int* tgtLengths,
     const CAST_DTYPE* denominators,
     CAST_DTYPE* logProbs,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int& maxT = maxSrcLen;
   const int& maxU = maxTgtLen;
   const int& D = numTargets;
@@ -49,22 +48,12 @@ __global__ void ComputeLogProbs(
   logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
       CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];
-  if (!fusedLogSmax) {
-    logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
-        CAST_DTYPE(logits[idx * D + blank]);
-  }
-
   if (u < U - 1) {
     // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
     // u).
     int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
     logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
         CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
-
-    if (!fusedLogSmax) {
-      logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
-          CAST_DTYPE(logits[idx * D + target]);
-    }
   }
 }
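
The surviving code path is the fused one: `denominators` holds the per-frame logsumexp over the class dimension, so subtracting it from a raw logit is exactly a log_softmax. The deleted `!fusedLogSmax` branches overwrote these values with the unnormalized logits. A small sketch of the identity the kernel relies on:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 3, 7)  # (B, T, U + 1, D); illustrative shape

# The kernel's "denominators" are the logsumexp over the class dimension.
denominators = logits.logsumexp(dim=-1, keepdim=True)

# Fused log softmax: log_prob[..., d] = logits[..., d] - denominator, which
# is what the blank/emit assignments above compute element by element.
fused = logits - denominators
assert torch.allclose(fused, F.log_softmax(logits, dim=-1), atol=1e-6)
```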
@@ -330,8 +319,7 @@ __global__ void ComputeGradients(
     const CAST_DTYPE* alphas,
     const CAST_DTYPE* betas,
     DTYPE* gradients,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int bTgt = blockIdx.z; // 0 <= b < B
   const int t = blockIdx.x * blockDim.x + threadIdx.x;
   const int u = blockIdx.y;
@@ -353,8 +341,7 @@ __global__ void ComputeGradients(
       alphas,
       betas,
       gradients,
-      H,
-      fusedLogSmax);
+      H);
 }

 // This is a __global__ wrapper around ComputeAlphas
...
@@ -102,8 +102,6 @@ status_t Compute(
   const int& blank = options.blank_;
   const CAST_DTYPE clamp = options.clamp_;
-  const bool& fusedLogSmax = options.fusedLogSmax_;
-
   { // compute denominators.
     status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
         /*stream=*/stream,
@@ -134,8 +132,7 @@ status_t Compute(
         /*tgtLengths=*/tgtLengths,
         /*denominators=*/workspace.GetPointerToDenominators(),
         /*log_probs=*/workspace.GetPointerToLogProbs(),
-        H,
-        fusedLogSmax);
+        H);

     if (cudaGetLastError() != cudaSuccess) {
       return COMPUTE_LOG_PROBS_FAILED;
@@ -200,8 +197,7 @@ status_t Compute(
         /*alphas=*/workspace.GetPointerToAlphas(),
         /*betas=*/workspace.GetPointerToBetas(),
         /*gradients=*/gradients,
-        H,
-        fusedLogSmax);
+        H);

     if (cudaGetLastError() != cudaSuccess) {
       return COMPUTE_GRADIENTS_FAILED;
     }
...
@@ -26,8 +26,7 @@ HOST_AND_DEVICE void ComputeGradientsElement(
     const CAST_DTYPE* alphas,
     const CAST_DTYPE* betas,
     DTYPE* gradients,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int& maxT = maxSrcLen;
   const int& maxU = maxTgtLen;
   const int& D = numTargets;
@@ -79,44 +78,22 @@ HOST_AND_DEVICE void ComputeGradientsElement(
       int b_t_u_d = idx_b_t_u * D + d;
       CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;

-      if (fusedLogSmax) {
-        if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
-        } else if (t < T - 1 && d == blank) {
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
-          if (idx_b_tp1_u != -1) {
-            gradients[b_t_u_d] =
-                gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
-          }
-        } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
-          if (idx_b_t_up1 != -1) {
-            gradients[b_t_u_d] =
-                gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
-          }
-        } else {
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
-        }
-      } else { // Non fused log softmax case
-        CAST_DTYPE g = cost + CAST_DTYPE(logits[b_t_u_d]);
-        if (d == blank && t == T - 1 && u == U - 1) {
-          gradients[b_t_u_d] = g + alphas[idx_b_t_u];
-        } else if (t < T - 1 && d == blank) {
-          if (idx_b_tp1_u != -1) {
-            gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_tp1_u];
-          } else {
-            gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
-          }
-        } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
-          if (idx_b_t_up1 != -1) {
-            gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_t_up1];
-          } else {
-            gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
-          }
-        } else {
-          gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
-        }
-        gradients[b_t_u_d] = -std::exp(gradients[b_t_u_d]);
-      }
+      if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
+      } else if (t < T - 1 && d == blank) {
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
+        if (idx_b_tp1_u != -1) {
+          gradients[b_t_u_d] =
+              gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
+        }
+      } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
+        if (idx_b_t_up1 != -1) {
+          gradients[b_t_u_d] =
+              gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
+        }
+      } else {
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
+      }

       if (clamp > 0) {
...
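
What remains is the fused-softmax gradient: `g` bundles the logit with the alpha/cost/denominator terms accumulated in `c` earlier in the function, and every class gets `exp(g + beta)` minus an extra `exp` term on the transition (blank or the correct emit) that the lattice actually takes. A hedged way to sanity-check gradients of this form is to autograd through an explicit log_softmax plus the standard alpha recursion, much like the numpy reference transducer used by the tests (a minimal single-sequence sketch, not the library's code):

```python
import torch
import torch.nn.functional as F


def reference_rnnt_loss(logits, targets, blank):
    """Single-sequence RNN-T loss via the standard alpha recursion.

    logits: (T, U + 1, D) raw scores; targets: (U,) label indices.
    """
    T, U1, _ = logits.shape
    U = U1 - 1
    log_probs = F.log_softmax(logits, dim=-1)  # explicit, un-fused normalization
    neg_inf = torch.tensor(float("-inf"))
    alpha = [[neg_inf] * (U + 1) for _ in range(T)]
    alpha[0][0] = torch.tensor(0.0)
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # Move right in time by emitting blank ...
            stay = alpha[t - 1][u] + log_probs[t - 1, u, blank] if t > 0 else neg_inf
            # ... or move up by emitting the next target label.
            emit = (alpha[t][u - 1] + log_probs[t, u - 1, targets[u - 1]]
                    if u > 0 else neg_inf)
            alpha[t][u] = torch.logaddexp(stay, emit)
    # Terminal blank closes the lattice.
    return -(alpha[T - 1][U] + log_probs[T - 1, U, blank])


logits = torch.randn(4, 3, 5, requires_grad=True)
loss = reference_rnnt_loss(logits, torch.tensor([1, 2]), blank=4)
loss.backward()
# logits.grad should match the fused kernel's output up to numerical tolerance,
# since the kernel folds the log_softmax backward into the exp(g + beta) terms.
```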
@@ -42,12 +42,6 @@ typedef struct Options {
   // num_targets = D.
   int numTargets_;

-  // if set to true, inputs are logits and gradients are
-  // fused with logsoftmax gradients.
-  // if set to false, log_softmax is computed outside of loss
-  // True by default
-  bool fusedLogSmax_;
-
   Options()
       : device_(UNDEFINED),
         numThreads_(0),
@@ -58,8 +52,7 @@ typedef struct Options {
         nHypos_(1),
         maxSrcLen_(0),
         maxTgtLen_(0),
-        numTargets_(0),
-        fusedLogSmax_(true) {}
+        numTargets_(0) {}

   int BU() const {
     return batchSize_ * maxTgtLen_ * nHypos_;
...
@@ -14,7 +14,6 @@ def rnnt_loss(
     target_lengths: Tensor,
     blank: int = -1,
     clamp: float = -1,
-    fused_log_softmax: bool = True,
     reduction: str = "mean",
 ):
     """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
@@ -31,7 +30,6 @@ def rnnt_loss(
         target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
-        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
@@ -42,9 +40,6 @@ def rnnt_loss(
     if reduction not in ['none', 'mean', 'sum']:
         raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")

-    if not fused_log_softmax:
-        logits = torch.nn.functional.log_softmax(logits, dim=-1)
-
     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
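
Callers that used to pass `fused_log_softmax=False` and normalize outside the loss can now just pass raw logits. Incidentally, log_softmax is idempotent in value (its output already has logsumexp zero), so feeding already-normalized inputs to the fused path yields the same log-probabilities; a quick check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(3, 7)
once = F.log_softmax(x, dim=-1)
twice = F.log_softmax(once, dim=-1)

# exp(once) sums to 1 per row, so the second normalization subtracts log(1) = 0.
assert torch.allclose(once, twice, atol=1e-6)
```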
@@ -55,7 +50,6 @@ def rnnt_loss(
         target_lengths=target_lengths,
         blank=blank,
         clamp=clamp,
-        fused_log_softmax=fused_log_softmax
     )

     if reduction == 'mean':
@@ -77,7 +71,6 @@ class RNNTLoss(torch.nn.Module):
     Args:
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
-        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
     """
@@ -86,13 +79,11 @@ class RNNTLoss(torch.nn.Module):
         self,
         blank: int = -1,
         clamp: float = -1.,
-        fused_log_softmax: bool = True,
         reduction: str = "mean",
     ):
         super().__init__()
         self.blank = blank
         self.clamp = clamp
-        self.fused_log_softmax = fused_log_softmax
         self.reduction = reduction

     def forward(
@@ -120,6 +111,5 @@ class RNNTLoss(torch.nn.Module):
             target_lengths,
             self.blank,
             self.clamp,
-            self.fused_log_softmax,
             self.reduction
         )
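
A minimal usage sketch of the module after this change (the import path is an assumption: the loss was in a prototype namespace at the time of this commit and moved to `torchaudio.transforms` in later releases; shapes are illustrative):

```python
import torch
from torchaudio.transforms import RNNTLoss  # assumed post-prototype path

logits = torch.rand(1, 10, 3, 4, requires_grad=True)  # (batch, max_T, max_U + 1, D)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([10], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

# No fused_log_softmax argument anymore: just blank, clamp, and reduction.
loss_fn = RNNTLoss(blank=-1, clamp=-1.0, reduction="mean")
loss = loss_fn(logits, targets, logit_lengths, target_lengths)
loss.backward()
```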