Unverified commit d74d0604 authored by Caroline Chen, committed by GitHub

Remove fused_log_softmax option from RNNT Loss (#1615)

parent 9078c0b9
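
Note for readers: after this change the loss always applies log_softmax to raw logits internally; the escape hatch for normalizing outside the loss is gone. A minimal sketch of the post-change call, assuming an import path like torchaudio.transforms.RNNTLoss (at the time of this commit the loss lived in a prototype namespace, so the exact path is an assumption):

    import torch
    from torchaudio.transforms import RNNTLoss  # assumed import path

    B, T, U, D = 2, 10, 3, 5  # batch, time frames, target length, vocab size (incl. blank)
    logits = torch.randn(B, T, U + 1, D, requires_grad=True)  # raw, unnormalized scores
    targets = torch.randint(0, D - 1, (B, U), dtype=torch.int32)
    logit_lengths = torch.full((B,), T, dtype=torch.int32)
    target_lengths = torch.full((B,), U, dtype=torch.int32)

    loss_fn = RNNTLoss(blank=-1, clamp=-1.0, reduction="mean")  # no fused_log_softmax argument
    loss = loss_fn(logits, targets, logit_lengths, target_lengths)
    loss.backward()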
@@ -71,7 +71,6 @@ class Autograd(TestBaseMixin):
             data["target_lengths"],  # target_lengths
             data["blank"],  # blank
             -1,  # clamp
-            True,  # fused_log_softmax
         )
         self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)
 
......
@@ -5,7 +5,6 @@ from .utils import (
     compute_with_numpy_transducer,
     compute_with_pytorch_transducer,
     get_basic_data,
-    get_B1_T10_U3_D4_data,
     get_B1_T2_U3_D5_data,
     get_B2_T4_U3_D3_data,
     get_random_data,
@@ -80,18 +79,3 @@ class RNNTLossTest:
         self._test_costs_and_gradients(
             data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
         )
-
-    def test_rnnt_nonfused_log_softmax(self):
-        for random in [False, True]:
-            data = get_B1_T10_U3_D4_data(
-                random=random,
-                dtype=torch.float32,
-                device=self.device,
-            )
-            data["fused_log_softmax"] = False
-            ref_costs, ref_gradients = compute_with_numpy_transducer(
-                data=data
-            )
-            self._test_costs_and_gradients(
-                data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
-            )
@@ -26,7 +26,6 @@ def compute_with_numpy_transducer(data):
 def compute_with_pytorch_transducer(data):
     costs = RNNTLoss(
         blank=data["blank"],
-        fused_log_softmax=data.get("fused_log_softmax", True),
         reduction="none",
     )(
         logits=data["logits"],
......
@@ -13,17 +13,10 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
       const torch::Tensor& logit_lengths,
       const torch::Tensor& target_lengths,
       int64_t blank,
-      double clamp,
-      bool fused_log_softmax = true) {
+      double clamp) {
     torch::Tensor undef;
-    auto result = rnnt_loss(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank,
-        clamp,
-        fused_log_softmax);
+    auto result =
+        rnnt_loss(logits, targets, logit_lengths, target_lengths, blank, clamp);
     auto costs = std::get<0>(result);
     auto grads = std::get<1>(result).value_or(undef);
     ctx->save_for_backward({grads});
@@ -48,17 +41,10 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   at::AutoDispatchBelowADInplaceOrView guard;
   auto results = RNNTLossFunction::apply(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
+      logits, targets, logit_lengths, target_lengths, blank, clamp);
   return std::make_tuple(results[0], results[1]);
 }
 
......
@@ -7,19 +7,11 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                        .typed<decltype(rnnt_loss)>();
-  return op.call(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
+  return op.call(logits, targets, logit_lengths, target_lengths, blank, clamp);
 }
 
 TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
@@ -29,6 +21,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
       "Tensor logit_lengths,"
      "Tensor target_lengths,"
      "int blank,"
-      "float clamp,"
-      "bool fused_log_softmax=True) -> (Tensor, Tensor?)");
+      "float clamp) -> (Tensor, Tensor?)");
 }
@@ -8,5 +8,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax);
+    double clamp);
@@ -12,8 +12,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -81,7 +80,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_softmax;
 
   CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
   options.device_ = CPU;
......
@@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -82,7 +81,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_softmax;
 
   CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
   options.stream_ = at::cuda::getCurrentCUDAStream();
......
@@ -23,8 +23,7 @@ __global__ void ComputeLogProbs(
     const int* tgtLengths,
    const CAST_DTYPE* denominators,
    CAST_DTYPE* logProbs,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int& maxT = maxSrcLen;
   const int& maxU = maxTgtLen;
   const int& D = numTargets;
@@ -49,22 +48,12 @@ __global__ void ComputeLogProbs(
     logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
         CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];
-    if (!fusedLogSmax) {
-      logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
-          CAST_DTYPE(logits[idx * D + blank]);
-    }
 
     if (u < U - 1) {
       // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
       // u).
       int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
       logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
           CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
-      if (!fusedLogSmax) {
-        logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
-            CAST_DTYPE(logits[idx * D + target]);
-      }
     }
   }
@@ -330,8 +319,7 @@ __global__ void ComputeGradients(
     const CAST_DTYPE* alphas,
     const CAST_DTYPE* betas,
     DTYPE* gradients,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int bTgt = blockIdx.z; // 0 <= b < B
   const int t = blockIdx.x * blockDim.x + threadIdx.x;
   const int u = blockIdx.y;
@@ -353,8 +341,7 @@ __global__ void ComputeGradients(
       alphas,
       betas,
       gradients,
-      H,
-      fusedLogSmax);
+      H);
 }
 
 // This is a __global__ wrapper around ComputeAlphas
......
@@ -102,8 +102,6 @@ status_t Compute(
   const int& blank = options.blank_;
   const CAST_DTYPE clamp = options.clamp_;
-  const bool& fusedLogSmax = options.fusedLogSmax_;
-
   { // compute denominators.
     status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
         /*stream=*/stream,
@@ -134,8 +132,7 @@ status_t Compute(
         /*tgtLengths=*/tgtLengths,
         /*denominators=*/workspace.GetPointerToDenominators(),
         /*log_probs=*/workspace.GetPointerToLogProbs(),
-        H,
-        fusedLogSmax);
+        H);
 
     if (cudaGetLastError() != cudaSuccess) {
       return COMPUTE_LOG_PROBS_FAILED;
@@ -200,8 +197,7 @@ status_t Compute(
         /*alphas=*/workspace.GetPointerToAlphas(),
         /*betas=*/workspace.GetPointerToBetas(),
        /*gradients=*/gradients,
-        H,
-        fusedLogSmax);
+        H);
    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_GRADIENTS_FAILED;
    }
......
@@ -26,8 +26,7 @@ HOST_AND_DEVICE void ComputeGradientsElement(
     const CAST_DTYPE* alphas,
     const CAST_DTYPE* betas,
     DTYPE* gradients,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int& maxT = maxSrcLen;
   const int& maxU = maxTgtLen;
   const int& D = numTargets;
@@ -79,44 +78,22 @@ HOST_AND_DEVICE void ComputeGradientsElement(
       int b_t_u_d = idx_b_t_u * D + d;
       CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;
 
-      if (fusedLogSmax) {
-        if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
-        } else if (t < T - 1 && d == blank) {
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
-          if (idx_b_tp1_u != -1) {
-            gradients[b_t_u_d] =
-                gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
-          }
-        } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
-          if (idx_b_t_up1 != -1) {
-            gradients[b_t_u_d] =
-                gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
-          }
-        } else {
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
-        }
-      } else { // Non fused log softmax case
-        CAST_DTYPE g = cost + CAST_DTYPE(logits[b_t_u_d]);
-        if (d == blank && t == T - 1 && u == U - 1) {
-          gradients[b_t_u_d] = g + alphas[idx_b_t_u];
-        } else if (t < T - 1 && d == blank) {
-          if (idx_b_tp1_u != -1) {
-            gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_tp1_u];
-          } else {
-            gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
-          }
-        } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
-          if (idx_b_t_up1 != -1) {
-            gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_t_up1];
-          } else {
-            gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
-          }
-        } else {
-          gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
-        }
-        gradients[b_t_u_d] = -std::exp(gradients[b_t_u_d]);
-      }
+      if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
+      } else if (t < T - 1 && d == blank) {
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
+        if (idx_b_tp1_u != -1) {
+          gradients[b_t_u_d] =
+              gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
+        }
+      } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
+        if (idx_b_t_up1 != -1) {
+          gradients[b_t_u_d] =
+              gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
+        }
+      } else {
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
+      }
 
       if (clamp > 0) {
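
For readers tracing the kernel math: only the fused formulation survives, in which g folds the log_softmax normalization (the log-sum-exp denominator) into the path posterior. A hedged single-cell transcription in Python; the helper and its signature are invented for illustration, and None stands in for the -1 sentinel indices:

    import math

    def fused_gradient_cell(logit, denom, alpha, cost, beta_tu, beta_tp1_u,
                            beta_t_up1, d, blank, target, t, u, T, U):
        # g = alpha(t, u) + cost + log_softmax(logits)[d]
        #   = alpha(t, u) + cost + logit - denom, with denom the log-sum-exp over d
        g = logit + alpha + cost - denom
        if d == blank and t == T - 1 and u == U - 1:  # final blank transition
            return math.exp(g + beta_tu) - math.exp(g)
        if d == blank and t < T - 1:  # blank: advance to (t + 1, u)
            grad = math.exp(g + beta_tu)
            if beta_tp1_u is not None:  # stands in for idx_b_tp1_u != -1
                grad -= math.exp(g + beta_tp1_u)
            return grad
        if u < U - 1 and d == target:  # emit: advance to (t, u + 1)
            grad = math.exp(g + beta_tu)
            if beta_t_up1 is not None:  # stands in for idx_b_t_up1 != -1
                grad -= math.exp(g + beta_t_up1)
            return grad
        return math.exp(g + beta_tu)  # every other label only contributes softmax mass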
......
@@ -42,12 +42,6 @@ typedef struct Options {
   // num_targets = D.
   int numTargets_;
 
-  // if set to true, inputs are logits and gradients are
-  // fused with logsoftmax gradients.
-  // if set to false, log_softmax is computed outside of loss
-  // True by default
-  bool fusedLogSmax_;
-
   Options()
       : device_(UNDEFINED),
         numThreads_(0),
@@ -58,8 +52,7 @@ typedef struct Options {
         nHypos_(1),
         maxSrcLen_(0),
         maxTgtLen_(0),
-        numTargets_(0),
-        fusedLogSmax_(true) {}
+        numTargets_(0) {}
 
   int BU() const {
     return batchSize_ * maxTgtLen_ * nHypos_;
......
@@ -14,7 +14,6 @@ def rnnt_loss(
     target_lengths: Tensor,
     blank: int = -1,
     clamp: float = -1,
-    fused_log_softmax: bool = True,
     reduction: str = "mean",
 ):
     """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
@@ -31,7 +30,6 @@ def rnnt_loss(
         target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
-        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
@@ -42,9 +40,6 @@ def rnnt_loss(
     if reduction not in ['none', 'mean', 'sum']:
         raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")
 
-    if not fused_log_softmax:
-        logits = torch.nn.functional.log_softmax(logits, dim=-1)
-
     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
 
@@ -55,7 +50,6 @@ def rnnt_loss(
         target_lengths=target_lengths,
         blank=blank,
         clamp=clamp,
-        fused_log_softmax=fused_log_softmax
     )
 
     if reduction == 'mean':
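
With the external-normalization branch gone, log_softmax always happens inside the loss, so the result is invariant to shifting all logits by a constant. A quick hedged check, assuming the functional import path of later releases:

    import torch
    from torchaudio.functional import rnnt_loss  # assumed import path

    logits = torch.randn(1, 4, 3, 5)  # (batch, time, target len + 1, vocab)
    targets = torch.randint(0, 4, (1, 2), dtype=torch.int32)
    logit_lengths = torch.tensor([4], dtype=torch.int32)
    target_lengths = torch.tensor([2], dtype=torch.int32)

    base = rnnt_loss(logits, targets, logit_lengths, target_lengths)
    shifted = rnnt_loss(logits + 3.0, targets, logit_lengths, target_lengths)
    assert torch.allclose(base, shifted, atol=1e-5)  # log_softmax is shift-invariant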
@@ -77,7 +71,6 @@ class RNNTLoss(torch.nn.Module):
     Args:
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
-        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
     """
@@ -86,13 +79,11 @@ class RNNTLoss(torch.nn.Module):
         self,
         blank: int = -1,
         clamp: float = -1.,
-        fused_log_softmax: bool = True,
         reduction: str = "mean",
     ):
         super().__init__()
         self.blank = blank
         self.clamp = clamp
-        self.fused_log_softmax = fused_log_softmax
         self.reduction = reduction
 
     def forward(
@@ -120,6 +111,5 @@ class RNNTLoss(torch.nn.Module):
             target_lengths,
             self.blank,
             self.clamp,
-            self.fused_log_softmax,
             self.reduction
         )
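
Migration note: callers that passed fused_log_softmax=False and normalized outside the loss can simply drop their external log_softmax. And because log_softmax is idempotent (normalizing already-normalized log-probabilities is a no-op), code that keeps the external normalization still gets the same loss value; a hedged illustration under the same assumed import path as above:

    import torch
    import torch.nn.functional as F
    from torchaudio.transforms import RNNTLoss  # assumed import path

    logits = torch.randn(1, 4, 3, 5)
    targets = torch.randint(0, 4, (1, 2), dtype=torch.int32)
    logit_lengths = torch.tensor([4], dtype=torch.int32)
    target_lengths = torch.tensor([2], dtype=torch.int32)

    loss_fn = RNNTLoss()
    raw = loss_fn(logits, targets, logit_lengths, target_lengths)
    pre = loss_fn(F.log_softmax(logits, dim=-1), targets, logit_lengths, target_lengths)
    assert torch.allclose(raw, pre, atol=1e-5)  # idempotence: same value either way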