Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
0f603eb9
"vscode:/vscode.git/clone" did not exist on "52e2943cb7dc73ef35eb618fa00bcf2b165b7ece"
Unverified
Commit
0f603eb9
authored
Aug 17, 2021
by
Caroline Chen
Committed by
GitHub
Aug 17, 2021
Browse files
RNNT loss resolve null gradient (#1707)
parent
4ea80c56
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
18 deletions
+12
-18
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
+6
-0
torchaudio/csrc/rnnt/cpu/compute.cpp
torchaudio/csrc/rnnt/cpu/compute.cpp
+3
-9
torchaudio/csrc/rnnt/gpu/compute.cu
torchaudio/csrc/rnnt/gpu/compute.cu
+3
-9
No files found.
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
View file @
0f603eb9
...
@@ -27,6 +27,12 @@ class RNNTLossTest:
...
@@ -27,6 +27,12 @@ class RNNTLossTest:
loss
=
rnnt_loss
(
logits
,
targets
,
logit_lengths
,
target_lengths
)
loss
=
rnnt_loss
(
logits
,
targets
,
logit_lengths
,
target_lengths
)
loss
.
backward
()
loss
.
backward
()
def
test_basic_forward_no_grad
(
self
):
rnnt_loss
=
RNNTLoss
()
logits
,
targets
,
logit_lengths
,
target_lengths
=
get_basic_data
(
self
.
device
)
logits
.
requires_grad_
(
False
)
rnnt_loss
(
logits
,
targets
,
logit_lengths
,
target_lengths
)
def
test_costs_and_gradients_B1_T2_U3_D5_fp32
(
self
):
def
test_costs_and_gradients_B1_T2_U3_D5_fp32
(
self
):
data
,
ref_costs
,
ref_gradients
=
get_B1_T2_U3_D5_data
(
data
,
ref_costs
,
ref_gradients
=
get_B1_T2_U3_D5_data
(
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
...
...
torchaudio/csrc/rnnt/cpu/compute.cpp
View file @
0f603eb9
...
@@ -87,10 +87,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
...
@@ -87,10 +87,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch
::
Tensor
costs
=
torch
::
empty
(
torch
::
Tensor
costs
=
torch
::
empty
(
options
.
batchSize_
*
options
.
nHypos_
,
options
.
batchSize_
*
options
.
nHypos_
,
torch
::
TensorOptions
().
device
(
logits
.
device
()).
dtype
(
logits
.
dtype
()));
torch
::
TensorOptions
().
device
(
logits
.
device
()).
dtype
(
logits
.
dtype
()));
c10
::
optional
<
torch
::
Tensor
>
gradients
=
c10
::
nullopt
;
c10
::
optional
<
torch
::
Tensor
>
gradients
=
torch
::
zeros_like
(
logits
);
if
(
logits
.
requires_grad
())
{
gradients
=
torch
::
zeros_like
(
logits
);
}
torch
::
Tensor
int_workspace
=
torch
::
empty
(
torch
::
Tensor
int_workspace
=
torch
::
empty
(
IntWorkspace
::
ComputeSizeFromOptions
(
options
),
IntWorkspace
::
ComputeSizeFromOptions
(
options
),
...
@@ -120,8 +117,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
...
@@ -120,8 +117,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
/*logit_lengths=*/
logit_lengths
.
data_ptr
<
int
>
(),
/*logit_lengths=*/
logit_lengths
.
data_ptr
<
int
>
(),
/*target_lengths=*/
target_lengths
.
data_ptr
<
int
>
(),
/*target_lengths=*/
target_lengths
.
data_ptr
<
int
>
(),
/*costs=*/
costs
.
data_ptr
<
float
>
(),
/*costs=*/
costs
.
data_ptr
<
float
>
(),
/*gradients=*/
/*gradients=*/
gradients
->
data_ptr
<
float
>
());
(
gradients
==
c10
::
nullopt
)
?
nullptr
:
gradients
->
data_ptr
<
float
>
());
break
;
break
;
}
}
case
torch
::
ScalarType
::
Half
:
{
case
torch
::
ScalarType
::
Half
:
{
...
@@ -132,9 +128,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
...
@@ -132,9 +128,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
/*logit_lengths=*/
logit_lengths
.
data_ptr
<
int
>
(),
/*logit_lengths=*/
logit_lengths
.
data_ptr
<
int
>
(),
/*target_lengths=*/
target_lengths
.
data_ptr
<
int
>
(),
/*target_lengths=*/
target_lengths
.
data_ptr
<
int
>
(),
/*costs=*/
costs
.
data_ptr
<
c10
::
Half
>
(),
/*costs=*/
costs
.
data_ptr
<
c10
::
Half
>
(),
/*gradients=*/
/*gradients=*/
gradients
->
data_ptr
<
c10
::
Half
>
());
(
gradients
==
c10
::
nullopt
)
?
nullptr
:
gradients
->
data_ptr
<
c10
::
Half
>
());
break
;
break
;
}
}
default:
{
default:
{
...
...
torchaudio/csrc/rnnt/gpu/compute.cu
View file @
0f603eb9
...
@@ -90,10 +90,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
...
@@ -90,10 +90,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch
::
Tensor
costs
=
torch
::
empty
(
torch
::
Tensor
costs
=
torch
::
empty
(
options
.
batchSize_
*
options
.
nHypos_
,
options
.
batchSize_
*
options
.
nHypos_
,
torch
::
TensorOptions
().
device
(
logits
.
device
()).
dtype
(
logits
.
dtype
()));
torch
::
TensorOptions
().
device
(
logits
.
device
()).
dtype
(
logits
.
dtype
()));
c10
::
optional
<
torch
::
Tensor
>
gradients
=
c10
::
nullopt
;
c10
::
optional
<
torch
::
Tensor
>
gradients
=
torch
::
zeros_like
(
logits
);
if
(
logits
.
requires_grad
())
{
gradients
=
torch
::
zeros_like
(
logits
);
}
torch
::
Tensor
int_workspace
=
torch
::
empty
(
torch
::
Tensor
int_workspace
=
torch
::
empty
(
IntWorkspace
::
ComputeSizeFromOptions
(
options
),
IntWorkspace
::
ComputeSizeFromOptions
(
options
),
...
@@ -123,8 +120,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
...
@@ -123,8 +120,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
/*logit_lengths=*/
logit_lengths
.
data_ptr
<
int
>
(),
/*logit_lengths=*/
logit_lengths
.
data_ptr
<
int
>
(),
/*target_lengths=*/
target_lengths
.
data_ptr
<
int
>
(),
/*target_lengths=*/
target_lengths
.
data_ptr
<
int
>
(),
/*costs=*/
costs
.
data_ptr
<
float
>
(),
/*costs=*/
costs
.
data_ptr
<
float
>
(),
/*gradients=*/
/*gradients=*/
gradients
->
data_ptr
<
float
>
());
(
gradients
==
c10
::
nullopt
)
?
nullptr
:
gradients
->
data_ptr
<
float
>
());
break
;
break
;
}
}
case
torch
::
ScalarType
::
Half
:
{
case
torch
::
ScalarType
::
Half
:
{
...
@@ -135,9 +131,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
...
@@ -135,9 +131,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
/*logit_lengths=*/
logit_lengths
.
data_ptr
<
int
>
(),
/*logit_lengths=*/
logit_lengths
.
data_ptr
<
int
>
(),
/*target_lengths=*/
target_lengths
.
data_ptr
<
int
>
(),
/*target_lengths=*/
target_lengths
.
data_ptr
<
int
>
(),
/*costs=*/
costs
.
data_ptr
<
c10
::
Half
>
(),
/*costs=*/
costs
.
data_ptr
<
c10
::
Half
>
(),
/*gradients=*/
/*gradients=*/
gradients
->
data_ptr
<
c10
::
Half
>
());
(
gradients
==
c10
::
nullopt
)
?
nullptr
:
gradients
->
data_ptr
<
c10
::
Half
>
());
break
;
break
;
}
}
default:
{
default:
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment