Commit ca478823 authored by Caroline Chen, committed by Facebook GitHub Bot

Enable log probs input for rnnt loss (#2798)

Summary:
Add a `fused_log_softmax` argument (default/current behavior = `True`) to the rnnt loss.

If it is set to `False`, call `log_softmax` on the logits before passing them to the rnnt loss function.

The following should produce the same output:
```
rnnt_loss(logits, targets, logit_lengths, target_lengths, fused_log_softmax=True)
```

```
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
rnnt_loss(log_probs, targets, logit_lengths, target_lengths, fused_log_softmax=False)
```
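
For reference, a minimal runnable sketch of this equivalence check (the shapes, `blank=0`, seed, and tolerances below are illustrative assumptions, not part of this change):

```
import torch
from torchaudio.functional import rnnt_loss

# Illustrative sizes: B batches, T time steps, U target tokens, D classes.
torch.manual_seed(0)
B, T, U, D = 2, 10, 2, 4
logits = torch.randn(B, T, U + 1, D, requires_grad=True)
targets = torch.randint(1, D, (B, U), dtype=torch.int32)  # label 0 reserved for blank
logit_lengths = torch.full((B,), T, dtype=torch.int32)
target_lengths = torch.full((B,), U, dtype=torch.int32)

# Fused path: raw logits go in, log_softmax happens inside the loss.
fused = rnnt_loss(logits, targets, logit_lengths, target_lengths, blank=0)

# Non-fused path: apply log_softmax first, then tell the loss to skip it.
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
nonfused = rnnt_loss(
    log_probs, targets, logit_lengths, target_lengths, blank=0, fused_log_softmax=False
)

print(torch.allclose(fused, nonfused, atol=1e-6))  # expected: True

# Gradients w.r.t. the raw logits should also agree once autograd
# chains through the external log_softmax.
(g_fused,) = torch.autograd.grad(fused, logits)
(g_nonfused,) = torch.autograd.grad(nonfused, logits)
print(torch.allclose(g_fused, g_nonfused, atol=1e-5))  # expected: True
```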

Testing: unit tests, plus verifying that the conformer rnnt recipe produces the same results.

Pull Request resolved: https://github.com/pytorch/audio/pull/2798

Reviewed By: xiaohui-zhang

Differential Revision: D41083523

Pulled By: carolineechen

fbshipit-source-id: e15442ceed1f461bbf06b724aa0561ff8827ad61
parent 2d99fee2
......@@ -189,13 +189,19 @@ def compute_with_numpy_transducer(data):
def compute_with_pytorch_transducer(data):
fused_log_softmax = data.get("fused_log_softmax", True)
input = data["logits"]
if not fused_log_softmax:
input = torch.nn.functional.log_softmax(input, dim=-1)
costs = rnnt_loss(
logits=data["logits"],
logits=input,
logit_lengths=data["logit_lengths"],
target_lengths=data["target_lengths"],
targets=data["targets"],
blank=data["blank"],
reduction="none",
fused_log_softmax=fused_log_softmax,
)
loss = torch.sum(costs)
......@@ -260,6 +266,7 @@ def get_B1_T10_U3_D4_data(
data["target_lengths"] = torch.tensor([2, 2], dtype=torch.int32, device=device)
data["targets"] = torch.tensor([[1, 2], [1, 2]], dtype=torch.int32, device=device)
data["blank"] = 0
data["fused_log_softmax"] = False
return data
......@@ -552,6 +559,7 @@ def get_random_data(
max_U=32,
max_D=40,
blank=-1,
fused_log_softmax=True,
dtype=torch.float32,
device=CPU_DEVICE,
seed=None,
......@@ -591,6 +599,7 @@ def get_random_data(
"logit_lengths": logit_lengths,
"target_lengths": target_lengths,
"blank": blank,
"fused_log_softmax": fused_log_softmax,
}
......
......@@ -51,6 +51,7 @@ class Functional(TestBaseMixin):
def _test_costs_and_gradients(self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2):
logits_shape = data["logits"].shape
costs, gradients = rnnt_utils.compute_with_pytorch_transducer(data=data)
self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(logits_shape, gradients.shape)
self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
......@@ -637,13 +638,25 @@ class Functional(TestBaseMixin):
rtol=rtol,
)
def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self):
@parameterized.expand([(True,), (False,)])
def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self, fused_log_softmax):
seed = 777
for i in range(5):
data = rnnt_utils.get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i))
data = rnnt_utils.get_random_data(
fused_log_softmax=fused_log_softmax, dtype=torch.float32, device=self.device, seed=(seed + i)
)
ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients)
def test_rnnt_loss_nonfused_softmax(self):
data = rnnt_utils.get_B1_T10_U3_D4_data()
ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
self._test_costs_and_gradients(
data=data,
ref_costs=ref_costs,
ref_gradients=ref_gradients,
)
def test_psd(self):
"""Verify the ``F.psd`` method by the numpy implementation.
Given the multi-channel complex-valued spectrum as the input,
......
......@@ -13,10 +13,17 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
double clamp,
bool fused_log_softmax = true) {
torch::Tensor undef;
auto result =
rnnt_loss(logits, targets, logit_lengths, target_lengths, blank, clamp);
auto result = rnnt_loss(
logits,
targets,
logit_lengths,
target_lengths,
blank,
clamp,
fused_log_softmax);
auto costs = std::get<0>(result);
auto grads = std::get<1>(result).value_or(undef);
ctx->save_for_backward({grads});
......@@ -41,10 +48,17 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
double clamp,
bool fused_log_softmax = true) {
at::AutoDispatchBelowADInplaceOrView guard;
auto results = RNNTLossFunction::apply(
logits, targets, logit_lengths, target_lengths, blank, clamp);
logits,
targets,
logit_lengths,
target_lengths,
blank,
clamp,
fused_log_softmax);
return std::make_tuple(results[0], results[1]);
}
......
......@@ -7,11 +7,19 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
double clamp,
bool fused_log_softmax = true) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("torchaudio::rnnt_loss", "")
.typed<decltype(rnnt_loss)>();
return op.call(logits, targets, logit_lengths, target_lengths, blank, clamp);
return op.call(
logits,
targets,
logit_lengths,
target_lengths,
blank,
clamp,
fused_log_softmax);
}
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
......@@ -21,5 +29,6 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
"Tensor logit_lengths,"
"Tensor target_lengths,"
"int blank,"
"float clamp) -> (Tensor, Tensor?)");
"float clamp,"
"bool fused_log_softmax) -> (Tensor, Tensor?)");
}
......@@ -8,4 +8,5 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp);
double clamp,
bool fused_log_softmax);
......@@ -12,7 +12,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
double clamp,
bool fused_log_softmax = true) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
......@@ -80,6 +81,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
options.fusedLogSmax_ = fused_log_softmax;
TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
options.device_ = CPU;
......
......@@ -96,6 +96,7 @@ void ComputeLogProbsOneSequence(
const int& T = srcLen;
const int& U = tgtLen;
const int& blank = options.blank_;
const bool& fusedLogSmax = options.fusedLogSmax_;
for (int t = 0; t < T; ++t) {
for (int u = 0; u < U; ++u) {
......@@ -105,6 +106,13 @@ void ComputeLogProbsOneSequence(
}
logProbs({t, u}).skip() =
CAST_DTYPE(logits({t, u, blank})) - denom({t, u});
if (!fusedLogSmax) {
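// Non-fused path: the inputs are already log-probabilities, so use them
// directly instead of subtracting the denominator.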
if (u < U - 1) {
logProbs({t, u}).emit() = CAST_DTYPE(logits({t, u, targets[u]}));
}
logProbs({t, u}).skip() = CAST_DTYPE(logits({t, u, blank}));
}
}
}
}
......@@ -311,9 +319,11 @@ void ComputeGradientsOneSequence(
const int& D = options.numTargets_;
const int& blank = options.blank_;
const CAST_DTYPE clamp = options.clamp_;
const bool& fusedLogSmax = options.fusedLogSmax_;
CAST_DTYPE cost = -beta({0, 0});
if (fusedLogSmax) {
// Note - below gradient is different from numpy_transducer, since we
// compute log_softmax more efficiently within the loss, to save memory. The
// details of the below implementation / equations can be found in Sec 3.2
......@@ -325,7 +335,8 @@ void ComputeGradientsOneSequence(
CAST_DTYPE c = alpha({t, u}) + cost - denom({t, u});
for (int d = 0; d < D; ++d) {
CAST_DTYPE g = CAST_DTYPE(logits({t, u, d})) + c;
if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
if (d == blank && t == T - 1 &&
u == U - 1) { // last blank transition.
gradients({t, u, d}) = std::exp(g + beta({t, u})) - std::exp(g);
} else if (d == blank && t < T - 1) {
gradients({t, u, d}) =
......@@ -346,6 +357,34 @@ void ComputeGradientsOneSequence(
}
}
}
} else {
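// Non-fused path: the incoming "logits" already hold log-probabilities;
// for each valid transition the gradient w.r.t. logProb(t, u, d) is minus
// its posterior probability, -exp(cost + logProb + alpha(t, u) + beta(next)),
// and zero elsewhere.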
for (int t = 0; t < T; ++t) {
for (int u = 0; u < U; ++u) {
for (int d = 0; d < D; ++d) {
CAST_DTYPE g = cost + CAST_DTYPE(logits({t, u, d}));
if (d == blank && t == T - 1 &&
u == U - 1) { // last blank transition.
gradients({t, u, d}) = g + alpha({t, u});
} else if (d == blank && t < T - 1) {
gradients({t, u, d}) = g + alpha({t, u}) + beta({t + 1, u});
} else if (u < U - 1 && d == targets[u]) {
gradients({t, u, d}) = g + alpha({t, u}) + beta({t, u + 1});
} else {
gradients({t, u, d}) = g + CAST_DTYPE(-INFINITY);
}
gradients({t, u, d}) = -(std::exp(gradients({t, u, d})));
if (clamp > 0) {
gradients({t, u, d}) =
math::min(CAST_DTYPE(gradients({t, u, d})), clamp);
gradients({t, u, d}) =
math::max(CAST_DTYPE(gradients({t, u, d})), -clamp);
}
}
}
}
}
// zero out the rest of the gradients, necessary when reusing logits memory
// check the memory location to see if it's necessary
......
......@@ -13,7 +13,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
double clamp,
bool fused_log_softmax = true) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
......@@ -81,6 +82,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
options.fusedLogSmax_ = fused_log_softmax;
TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::cuda::getCurrentCUDAStream();
......
......@@ -23,7 +23,8 @@ __global__ void ComputeLogProbs(
const int* tgtLengths,
const CAST_DTYPE* denominators,
CAST_DTYPE* logProbs,
int H = 1) {
int H = 1,
bool fusedLogSmax = true) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int& D = numTargets;
......@@ -48,12 +49,22 @@ __global__ void ComputeLogProbs(
logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];
if (!fusedLogSmax) {
logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
CAST_DTYPE(logits[idx * D + blank]);
}
if (u < U - 1) {
// emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
// u).
int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
if (!fusedLogSmax) {
logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
CAST_DTYPE(logits[idx * D + target]);
}
}
}
......@@ -319,7 +330,8 @@ __global__ void ComputeGradients(
const CAST_DTYPE* alphas,
const CAST_DTYPE* betas,
DTYPE* gradients,
int H = 1) {
int H = 1,
bool fusedLogSmax = true) {
const int bTgt = blockIdx.z; // 0 <= b < B
const int t = blockIdx.x * blockDim.x + threadIdx.x;
const int u = blockIdx.y;
......@@ -341,7 +353,8 @@ __global__ void ComputeGradients(
alphas,
betas,
gradients,
H);
H,
fusedLogSmax);
}
// This is a __global__ wrapper around ComputeAlphas
......
......@@ -102,6 +102,8 @@ status_t Compute(
const int& blank = options.blank_;
const CAST_DTYPE clamp = options.clamp_;
const bool& fusedLogSmax = options.fusedLogSmax_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
......@@ -132,7 +134,8 @@ status_t Compute(
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H);
H,
fusedLogSmax);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
......@@ -197,7 +200,8 @@ status_t Compute(
/*alphas=*/workspace.GetPointerToAlphas(),
/*betas=*/workspace.GetPointerToBetas(),
/*gradients=*/gradients,
H);
H,
fusedLogSmax);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_GRADIENTS_FAILED;
}
......
......@@ -26,7 +26,8 @@ HOST_AND_DEVICE void ComputeGradientsElement(
const CAST_DTYPE* alphas,
const CAST_DTYPE* betas,
DTYPE* gradients,
int H = 1) {
int H = 1,
bool fusedLogSmax = true) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int& D = numTargets;
......@@ -78,6 +79,7 @@ HOST_AND_DEVICE void ComputeGradientsElement(
int b_t_u_d = idx_b_t_u * D + d;
CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;
if (fusedLogSmax) {
if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
} else if (t < T - 1 && d == blank) {
......@@ -95,6 +97,27 @@ HOST_AND_DEVICE void ComputeGradientsElement(
} else {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
}
} else { // Non fused log softmax case
CAST_DTYPE g = cost + CAST_DTYPE(logits[b_t_u_d]);
if (d == blank && t == T - 1 && u == U - 1) {
gradients[b_t_u_d] = g + alphas[idx_b_t_u];
} else if (t < T - 1 && d == blank) {
if (idx_b_tp1_u != -1) {
gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_tp1_u];
} else {
gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
}
} else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
if (idx_b_t_up1 != -1) {
gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_t_up1];
} else {
gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
}
} else {
gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
}
gradients[b_t_u_d] = -std::exp(gradients[b_t_u_d]);
}
if (clamp > 0) {
auto g = CAST_DTYPE(gradients[b_t_u_d]);
......
......@@ -42,6 +42,12 @@ typedef struct Options {
// num_targets = D.
int numTargets_;
// if set to true, inputs are logits and gradients are
// fused with logsoftmax gradients.
// if set to false, log_softmax is computed outside of loss
// True by default
bool fusedLogSmax_;
Options()
: device_(UNDEFINED),
numThreads_(0),
......@@ -52,7 +58,8 @@ typedef struct Options {
nHypos_(1),
maxSrcLen_(0),
maxTgtLen_(0),
numTargets_(0) {}
numTargets_(0),
fusedLogSmax_(true) {}
int BU() const {
return batchSize_ * maxTgtLen_ * nHypos_;
......
......@@ -1838,6 +1838,7 @@ def rnnt_loss(
blank: int = -1,
clamp: float = -1,
reduction: str = "mean",
fused_log_softmax: bool = True,
):
"""Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
:cite:`graves2012sequence`.
......@@ -1860,6 +1861,7 @@ def rnnt_loss(
clamp (float, optional): clamp for gradients (Default: ``-1``)
reduction (string, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. (Default: ``"mean"``)
fused_log_softmax (bool): set to ``False`` if calling ``log_softmax`` outside of the loss (Default: ``True``)
Returns:
Tensor: Loss with the reduction option applied. If ``reduction`` is ``"none"``, then size `(batch)`,
otherwise scalar.
......@@ -1877,6 +1879,7 @@ def rnnt_loss(
target_lengths=target_lengths,
blank=blank,
clamp=clamp,
fused_log_softmax=fused_log_softmax,
)
if reduction == "mean":
......
......@@ -1744,6 +1744,7 @@ class RNNTLoss(torch.nn.Module):
clamp (float, optional): clamp for gradients (Default: ``-1``)
reduction (string, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. (Default: ``"mean"``)
fused_log_softmax (bool): set to ``False`` if calling ``log_softmax`` outside of the loss (Default: ``True``)
Example
>>> # Hypothetical values
......@@ -1768,11 +1769,13 @@ class RNNTLoss(torch.nn.Module):
blank: int = -1,
clamp: float = -1.0,
reduction: str = "mean",
fused_log_softmax: bool = True,
):
super().__init__()
self.blank = blank
self.clamp = clamp
self.reduction = reduction
self.fused_log_softmax = fused_log_softmax
def forward(
self,
......@@ -1792,4 +1795,13 @@ class RNNTLoss(torch.nn.Module):
Tensor: Loss with the reduction option applied. If ``reduction`` is ``"none"``, then size (batch),
otherwise scalar.
"""
return F.rnnt_loss(logits, targets, logit_lengths, target_lengths, self.blank, self.clamp, self.reduction)
return F.rnnt_loss(
logits,
targets,
logit_lengths,
target_lengths,
self.blank,
self.clamp,
self.reduction,
self.fused_log_softmax,
)