Unverified Commit 9e362816 authored by dgenzel2's avatar dgenzel2 Committed by GitHub
Browse files
parent 6d9c04d8
...@@ -16,7 +16,6 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> { ...@@ -16,7 +16,6 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
double clamp, double clamp,
bool fused_log_smax = true, bool fused_log_smax = true,
bool reuse_logits_for_grads = true) { bool reuse_logits_for_grads = true) {
at::AutoNonVariableTypeMode g;
torch::Tensor undef; torch::Tensor undef;
auto result = rnnt_loss( auto result = rnnt_loss(
logits, logits,
...@@ -54,6 +53,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd( ...@@ -54,6 +53,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
double clamp, double clamp,
bool fused_log_smax = true, bool fused_log_smax = true,
bool reuse_logits_for_grads = true) { bool reuse_logits_for_grads = true) {
at::AutoDispatchBelowADInplaceOrView guard;
auto results = RNNTLossFunction::apply( auto results = RNNTLossFunction::apply(
logits, logits,
targets, targets,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment