Unverified Commit 0f603eb9 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

RNNT loss resolve null gradient (#1707)

parent 4ea80c56
...@@ -27,6 +27,12 @@ class RNNTLossTest: ...@@ -27,6 +27,12 @@ class RNNTLossTest:
loss = rnnt_loss(logits, targets, logit_lengths, target_lengths) loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
loss.backward() loss.backward()
def test_basic_forward_no_grad(self):
rnnt_loss = RNNTLoss()
logits, targets, logit_lengths, target_lengths = get_basic_data(self.device)
logits.requires_grad_(False)
rnnt_loss(logits, targets, logit_lengths, target_lengths)
def test_costs_and_gradients_B1_T2_U3_D5_fp32(self): def test_costs_and_gradients_B1_T2_U3_D5_fp32(self):
data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data( data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data(
dtype=torch.float32, dtype=torch.float32,
......
...@@ -87,10 +87,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute( ...@@ -87,10 +87,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch::Tensor costs = torch::empty( torch::Tensor costs = torch::empty(
options.batchSize_ * options.nHypos_, options.batchSize_ * options.nHypos_,
torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
c10::optional<torch::Tensor> gradients = c10::nullopt; c10::optional<torch::Tensor> gradients = torch::zeros_like(logits);
if (logits.requires_grad()) {
gradients = torch::zeros_like(logits);
}
torch::Tensor int_workspace = torch::empty( torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options), IntWorkspace::ComputeSizeFromOptions(options),
...@@ -120,8 +117,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute( ...@@ -120,8 +117,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
/*logit_lengths=*/logit_lengths.data_ptr<int>(), /*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(), /*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(), /*costs=*/costs.data_ptr<float>(),
/*gradients=*/ /*gradients=*/gradients->data_ptr<float>());
(gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
break; break;
} }
case torch::ScalarType::Half: { case torch::ScalarType::Half: {
...@@ -132,9 +128,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute( ...@@ -132,9 +128,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
/*logit_lengths=*/logit_lengths.data_ptr<int>(), /*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(), /*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<c10::Half>(), /*costs=*/costs.data_ptr<c10::Half>(),
/*gradients=*/ /*gradients=*/gradients->data_ptr<c10::Half>());
(gradients == c10::nullopt) ? nullptr
: gradients->data_ptr<c10::Half>());
break; break;
} }
default: { default: {
......
...@@ -90,10 +90,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute( ...@@ -90,10 +90,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch::Tensor costs = torch::empty( torch::Tensor costs = torch::empty(
options.batchSize_ * options.nHypos_, options.batchSize_ * options.nHypos_,
torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
c10::optional<torch::Tensor> gradients = c10::nullopt; c10::optional<torch::Tensor> gradients = torch::zeros_like(logits);
if (logits.requires_grad()) {
gradients = torch::zeros_like(logits);
}
torch::Tensor int_workspace = torch::empty( torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options), IntWorkspace::ComputeSizeFromOptions(options),
...@@ -123,8 +120,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute( ...@@ -123,8 +120,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
/*logit_lengths=*/logit_lengths.data_ptr<int>(), /*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(), /*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(), /*costs=*/costs.data_ptr<float>(),
/*gradients=*/ /*gradients=*/gradients->data_ptr<float>());
(gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
break; break;
} }
case torch::ScalarType::Half: { case torch::ScalarType::Half: {
...@@ -135,9 +131,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute( ...@@ -135,9 +131,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
/*logit_lengths=*/logit_lengths.data_ptr<int>(), /*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(), /*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<c10::Half>(), /*costs=*/costs.data_ptr<c10::Half>(),
/*gradients=*/ /*gradients=*/gradients->data_ptr<c10::Half>());
(gradients == c10::nullopt) ? nullptr
: gradients->data_ptr<c10::Half>());
break; break;
} }
default: { default: {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment