Unverified Commit 16f3b2f9 authored by Caroline Chen, committed by GitHub

Remove reuse_logits_for_grads option for RNNTLoss (#1610)

parent 25ceee71
@@ -53,7 +53,7 @@ class Autograd(TestBaseMixin):
             data["logit_lengths"],
             data["target_lengths"],
         )
-        loss = RNNTLoss(blank=data["blank"], reuse_logits_for_grads=False)
+        loss = RNNTLoss(blank=data["blank"])
         self.assert_grad(loss, inputs, enable_all_grad=False)

@@ -72,7 +72,6 @@ class Autograd(TestBaseMixin):
             data["blank"],  # blank
             -1,  # clamp
             True,  # fused_log_softmax
-            False,  # reuse_logits_for_grads
         )
         self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)
@@ -17,14 +17,10 @@ class RNNTLossTest:
         self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
     ):
         logits_shape = data["logits"].shape
-        for reuse_logits_for_grads in [False, True]:
-            with self.subTest(reuse_logits_for_grads=reuse_logits_for_grads):
-                costs, gradients = compute_with_pytorch_transducer(
-                    data=data, reuse_logits_for_grads=reuse_logits_for_grads
-                )
-                self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
-                self.assertEqual(logits_shape, gradients.shape)
-                self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
+        costs, gradients = compute_with_pytorch_transducer(data=data)
+        self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
+        self.assertEqual(logits_shape, gradients.shape)
+        self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)

     def test_basic_backward(self):
         rnnt_loss = RNNTLoss()
@@ -23,11 +23,10 @@ def compute_with_numpy_transducer(data):
     return costs, gradients


-def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
+def compute_with_pytorch_transducer(data):
     costs = RNNTLoss(
         blank=data["blank"],
         fused_log_softmax=data.get("fused_log_softmax", True),
-        reuse_logits_for_grads=reuse_logits_for_grads,
         reduction="none",
     )(
         logits=data["logits"],
@@ -14,8 +14,7 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
       const torch::Tensor& target_lengths,
       int64_t blank,
       double clamp,
-      bool fused_log_softmax = true,
-      bool reuse_logits_for_grads = true) {
+      bool fused_log_softmax = true) {
     torch::Tensor undef;
     auto result = rnnt_loss(
         logits,
@@ -24,8 +23,7 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
         target_lengths,
         blank,
         clamp,
-        fused_log_softmax,
-        reuse_logits_for_grads);
+        fused_log_softmax);
     auto costs = std::get<0>(result);
     auto grads = std::get<1>(result).value_or(undef);
     ctx->save_for_backward({grads});
@@ -51,8 +49,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax = true,
-    bool reuse_logits_for_grads = true) {
+    bool fused_log_softmax = true) {
   at::AutoDispatchBelowADInplaceOrView guard;
   auto results = RNNTLossFunction::apply(
       logits,
@@ -61,8 +58,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
       target_lengths,
       blank,
       clamp,
-      fused_log_softmax,
-      reuse_logits_for_grads);
+      fused_log_softmax);
   return std::make_tuple(results[0], results[1]);
 }
@@ -8,8 +8,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax = true,
-    bool reuse_logits_for_grads = true) {
+    bool fused_log_softmax = true) {
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                        .typed<decltype(rnnt_loss)>();
@@ -20,8 +19,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
       target_lengths,
       blank,
       clamp,
-      fused_log_softmax,
-      reuse_logits_for_grads);
+      fused_log_softmax);
 }

 TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
@@ -32,6 +30,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
       "Tensor target_lengths,"
       "int blank,"
       "float clamp,"
-      "bool fused_log_softmax=True,"
-      "bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)");
+      "bool fused_log_softmax=True) -> (Tensor, Tensor?)");
 }
@@ -9,5 +9,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax,
-    bool reuse_logits_for_grads);
+    bool fused_log_softmax);
@@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax = true,
-    bool reuse_logits_for_grads = true) {
+    bool fused_log_softmax = true) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -92,11 +91,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
   c10::optional<torch::Tensor> gradients = c10::nullopt;
   if (logits.requires_grad()) {
-    if (reuse_logits_for_grads) {
-      gradients = logits;
-    } else {
-      gradients = torch::zeros_like(logits);
-    }
+    gradients = torch::zeros_like(logits);
   }
   torch::Tensor int_workspace = torch::empty(
@@ -14,8 +14,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax = true,
-    bool reuse_logits_for_grads = true) {
+    bool fused_log_softmax = true) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -95,11 +94,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
   c10::optional<torch::Tensor> gradients = c10::nullopt;
   if (logits.requires_grad()) {
-    if (reuse_logits_for_grads) {
-      gradients = logits;
-    } else {
-      gradients = torch::zeros_like(logits);
-    }
+    gradients = torch::zeros_like(logits);
   }
   torch::Tensor int_workspace = torch::empty(
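Both the CPU and GPU backends now allocate the gradient buffer with torch::zeros_like(logits) unconditionally instead of optionally aliasing the logits storage. A minimal sketch of the user-visible consequence, assuming the prototype import path and illustrative tensor shapes (anything not taken directly from the diff is an assumption):

import torch
from torchaudio.prototype.rnnt_loss import rnnt_loss  # import path assumed for the prototype API

# Illustrative shapes: batch=1, T=4, U+1=3, num_labels=5.
logits = torch.randn(1, 4, 3, 5, requires_grad=True)
snapshot = logits.detach().clone()

costs = rnnt_loss(
    logits,
    torch.randint(0, 4, (1, 2), dtype=torch.int32),  # targets
    torch.tensor([4], dtype=torch.int32),            # logit_lengths
    torch.tensor([2], dtype=torch.int32),            # target_lengths
    reduction="none",
)
costs.sum().backward()

# With gradients kept in a separate zeros_like buffer (no reuse of the logits
# storage), the input values survive the loss computation unchanged.
assert torch.equal(logits.detach(), snapshot)
print(logits.grad.shape)  # same shape as logits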
@@ -15,7 +15,6 @@ def rnnt_loss(
     blank: int = -1,
     clamp: float = -1,
     fused_log_softmax: bool = True,
-    reuse_logits_for_grads: bool = True,
     reduction: str = "mean",
 ):
     """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
@@ -33,7 +32,6 @@ def rnnt_loss(
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
         fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
-        reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
@@ -46,9 +44,6 @@ def rnnt_loss(
     if not fused_log_softmax:
         logits = torch.nn.functional.log_softmax(logits, dim=-1)
-        reuse_logits_for_grads = (
-            False  # softmax needs the original logits value
-        )

     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
@@ -60,8 +55,8 @@ def rnnt_loss(
         target_lengths=target_lengths,
         blank=blank,
         clamp=clamp,
-        fused_log_softmax=fused_log_softmax,
-        reuse_logits_for_grads=reuse_logits_for_grads,)
+        fused_log_softmax=fused_log_softmax
+    )

     if reduction == 'mean':
         return costs.mean()
@@ -83,7 +78,6 @@ class RNNTLoss(torch.nn.Module):
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
         fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
-        reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
     """
@@ -93,14 +87,12 @@ class RNNTLoss(torch.nn.Module):
         blank: int = -1,
         clamp: float = -1.,
         fused_log_softmax: bool = True,
-        reuse_logits_for_grads: bool = True,
         reduction: str = "mean",
     ):
         super().__init__()
         self.blank = blank
         self.clamp = clamp
         self.fused_log_softmax = fused_log_softmax
-        self.reuse_logits_for_grads = reuse_logits_for_grads
         self.reduction = reduction

     def forward(
@@ -129,6 +121,5 @@ class RNNTLoss(torch.nn.Module):
             self.blank,
             self.clamp,
             self.fused_log_softmax,
-            self.reuse_logits_for_grads,
             self.reduction
         )
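For reference, a short usage sketch of the module API after the removal; the constructor arguments are the ones remaining in the diff above, while the import path and tensor shapes are assumptions for illustration:

import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss  # import path assumed for the prototype API

# Illustrative shapes: batch=2, T=10, U+1=6, num_labels=20.
logits = torch.randn(2, 10, 6, 20, requires_grad=True)
targets = torch.randint(0, 19, (2, 5), dtype=torch.int32)
logit_lengths = torch.tensor([10, 8], dtype=torch.int32)
target_lengths = torch.tensor([5, 3], dtype=torch.int32)

# reuse_logits_for_grads is gone; the remaining knobs are blank, clamp,
# fused_log_softmax, and reduction.
loss = RNNTLoss(blank=-1, clamp=-1.0, fused_log_softmax=True, reduction="mean")
cost = loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
)
cost.backward()
print(cost.item(), logits.grad.shape)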