Unverified commit 0f603eb9 authored by Caroline Chen, committed by GitHub

RNNT loss resolve null gradient (#1707)

parent 4ea80c56
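
The fix allocates the gradients buffer unconditionally with torch::zeros_like instead of only when logits.requires_grad() is true, so the compute kernel never receives a nullptr gradients pointer during a no-grad forward pass. Below is a minimal sketch of the scenario the new test_basic_forward_no_grad covers; the torchaudio.transforms.RNNTLoss import path and the tensor shapes are assumptions for illustration (older releases expose the loss under a prototype namespace), not part of this commit.

import torch
from torchaudio.transforms import RNNTLoss  # assumed import path

rnnt_loss = RNNTLoss()

# Illustrative shapes: logits is (batch, max_time, max_target_length + 1, num_classes),
# targets is (batch, max_target_length); targets and lengths are int32.
logits = torch.rand(1, 2, 3, 5)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

# Forward pass with autograd disabled on the logits: before this commit the
# C++ compute() passed nullptr as the gradients pointer in this case; it now
# always passes a zero-initialized gradients tensor.
logits.requires_grad_(False)
loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
print(loss.item())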
@@ -27,6 +27,12 @@ class RNNTLossTest:
         loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
         loss.backward()
 
+    def test_basic_forward_no_grad(self):
+        rnnt_loss = RNNTLoss()
+        logits, targets, logit_lengths, target_lengths = get_basic_data(self.device)
+        logits.requires_grad_(False)
+        rnnt_loss(logits, targets, logit_lengths, target_lengths)
+
     def test_costs_and_gradients_B1_T2_U3_D5_fp32(self):
         data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data(
             dtype=torch.float32,
@@ -87,10 +87,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
   torch::Tensor costs = torch::empty(
       options.batchSize_ * options.nHypos_,
       torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
-  c10::optional<torch::Tensor> gradients = c10::nullopt;
-  if (logits.requires_grad()) {
-    gradients = torch::zeros_like(logits);
-  }
+  c10::optional<torch::Tensor> gradients = torch::zeros_like(logits);
 
   torch::Tensor int_workspace = torch::empty(
       IntWorkspace::ComputeSizeFromOptions(options),
@@ -120,8 +117,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
           /*logit_lengths=*/logit_lengths.data_ptr<int>(),
           /*target_lengths=*/target_lengths.data_ptr<int>(),
           /*costs=*/costs.data_ptr<float>(),
-          /*gradients=*/
-          (gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
+          /*gradients=*/gradients->data_ptr<float>());
       break;
     }
     case torch::ScalarType::Half: {
@@ -132,9 +128,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
           /*logit_lengths=*/logit_lengths.data_ptr<int>(),
           /*target_lengths=*/target_lengths.data_ptr<int>(),
           /*costs=*/costs.data_ptr<c10::Half>(),
-          /*gradients=*/
-          (gradients == c10::nullopt) ? nullptr
-                                      : gradients->data_ptr<c10::Half>());
+          /*gradients=*/gradients->data_ptr<c10::Half>());
       break;
     }
     default: {
@@ -90,10 +90,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
   torch::Tensor costs = torch::empty(
       options.batchSize_ * options.nHypos_,
       torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
-  c10::optional<torch::Tensor> gradients = c10::nullopt;
-  if (logits.requires_grad()) {
-    gradients = torch::zeros_like(logits);
-  }
+  c10::optional<torch::Tensor> gradients = torch::zeros_like(logits);
 
   torch::Tensor int_workspace = torch::empty(
       IntWorkspace::ComputeSizeFromOptions(options),
@@ -123,8 +120,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
           /*logit_lengths=*/logit_lengths.data_ptr<int>(),
           /*target_lengths=*/target_lengths.data_ptr<int>(),
           /*costs=*/costs.data_ptr<float>(),
-          /*gradients=*/
-          (gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
+          /*gradients=*/gradients->data_ptr<float>());
       break;
     }
     case torch::ScalarType::Half: {
@@ -135,9 +131,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
           /*logit_lengths=*/logit_lengths.data_ptr<int>(),
           /*target_lengths=*/target_lengths.data_ptr<int>(),
           /*costs=*/costs.data_ptr<c10::Half>(),
-          /*gradients=*/
-          (gradients == c10::nullopt) ? nullptr
-                                      : gradients->data_ptr<c10::Half>());
+          /*gradients=*/gradients->data_ptr<c10::Half>());
       break;
     }
     default: {