Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
d74d0604
Unverified
Commit
d74d0604
authored
Aug 03, 2021
by
Caroline Chen
Committed by
GitHub
Aug 03, 2021
Browse files
Remove fused_log_softmax option from RNNT Loss (#1615)
parent
9078c0b9
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
32 additions
and
135 deletions
+32
-135
test/torchaudio_unittest/rnnt/autograd_impl.py
test/torchaudio_unittest/rnnt/autograd_impl.py
+0
-1
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
+0
-16
test/torchaudio_unittest/rnnt/utils.py
test/torchaudio_unittest/rnnt/utils.py
+0
-1
torchaudio/csrc/rnnt/autograd.cpp
torchaudio/csrc/rnnt/autograd.cpp
+5
-19
torchaudio/csrc/rnnt/compute.cpp
torchaudio/csrc/rnnt/compute.cpp
+3
-12
torchaudio/csrc/rnnt/compute.h
torchaudio/csrc/rnnt/compute.h
+1
-2
torchaudio/csrc/rnnt/cpu/compute.cpp
torchaudio/csrc/rnnt/cpu/compute.cpp
+1
-3
torchaudio/csrc/rnnt/gpu/compute.cu
torchaudio/csrc/rnnt/gpu/compute.cu
+1
-3
torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh
torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh
+3
-16
torchaudio/csrc/rnnt/gpu/gpu_transducer.h
torchaudio/csrc/rnnt/gpu/gpu_transducer.h
+2
-6
torchaudio/csrc/rnnt/gpu/kernels.h
torchaudio/csrc/rnnt/gpu/kernels.h
+15
-38
torchaudio/csrc/rnnt/options.h
torchaudio/csrc/rnnt/options.h
+1
-8
torchaudio/prototype/rnnt_loss.py
torchaudio/prototype/rnnt_loss.py
+0
-10
No files found.
test/torchaudio_unittest/rnnt/autograd_impl.py
View file @
d74d0604
...
@@ -71,7 +71,6 @@ class Autograd(TestBaseMixin):
...
@@ -71,7 +71,6 @@ class Autograd(TestBaseMixin):
data
[
"target_lengths"
],
# target_lengths
data
[
"target_lengths"
],
# target_lengths
data
[
"blank"
],
# blank
data
[
"blank"
],
# blank
-
1
,
# clamp
-
1
,
# clamp
True
,
# fused_log_softmax
)
)
self
.
assert_grad
(
rnnt_loss
,
inputs
,
enable_all_grad
=
False
)
self
.
assert_grad
(
rnnt_loss
,
inputs
,
enable_all_grad
=
False
)
...
...
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
View file @
d74d0604
...
@@ -5,7 +5,6 @@ from .utils import (
...
@@ -5,7 +5,6 @@ from .utils import (
compute_with_numpy_transducer
,
compute_with_numpy_transducer
,
compute_with_pytorch_transducer
,
compute_with_pytorch_transducer
,
get_basic_data
,
get_basic_data
,
get_B1_T10_U3_D4_data
,
get_B1_T2_U3_D5_data
,
get_B1_T2_U3_D5_data
,
get_B2_T4_U3_D3_data
,
get_B2_T4_U3_D3_data
,
get_random_data
,
get_random_data
,
...
@@ -80,18 +79,3 @@ class RNNTLossTest:
...
@@ -80,18 +79,3 @@ class RNNTLossTest:
self
.
_test_costs_and_gradients
(
self
.
_test_costs_and_gradients
(
data
=
data
,
ref_costs
=
ref_costs
,
ref_gradients
=
ref_gradients
data
=
data
,
ref_costs
=
ref_costs
,
ref_gradients
=
ref_gradients
)
)
def
test_rnnt_nonfused_log_softmax
(
self
):
for
random
in
[
False
,
True
]:
data
=
get_B1_T10_U3_D4_data
(
random
=
random
,
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
data
[
"fused_log_softmax"
]
=
False
ref_costs
,
ref_gradients
=
compute_with_numpy_transducer
(
data
=
data
)
self
.
_test_costs_and_gradients
(
data
=
data
,
ref_costs
=
ref_costs
,
ref_gradients
=
ref_gradients
)
test/torchaudio_unittest/rnnt/utils.py
View file @
d74d0604
...
@@ -26,7 +26,6 @@ def compute_with_numpy_transducer(data):
...
@@ -26,7 +26,6 @@ def compute_with_numpy_transducer(data):
def
compute_with_pytorch_transducer
(
data
):
def
compute_with_pytorch_transducer
(
data
):
costs
=
RNNTLoss
(
costs
=
RNNTLoss
(
blank
=
data
[
"blank"
],
blank
=
data
[
"blank"
],
fused_log_softmax
=
data
.
get
(
"fused_log_softmax"
,
True
),
reduction
=
"none"
,
reduction
=
"none"
,
)(
)(
logits
=
data
[
"logits"
],
logits
=
data
[
"logits"
],
...
...
torchaudio/csrc/rnnt/autograd.cpp
View file @
d74d0604
...
@@ -13,17 +13,10 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
...
@@ -13,17 +13,10 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
const
torch
::
Tensor
&
logit_lengths
,
const
torch
::
Tensor
&
logit_lengths
,
const
torch
::
Tensor
&
target_lengths
,
const
torch
::
Tensor
&
target_lengths
,
int64_t
blank
,
int64_t
blank
,
double
clamp
,
double
clamp
)
{
bool
fused_log_softmax
=
true
)
{
torch
::
Tensor
undef
;
torch
::
Tensor
undef
;
auto
result
=
rnnt_loss
(
auto
result
=
logits
,
rnnt_loss
(
logits
,
targets
,
logit_lengths
,
target_lengths
,
blank
,
clamp
);
targets
,
logit_lengths
,
target_lengths
,
blank
,
clamp
,
fused_log_softmax
);
auto
costs
=
std
::
get
<
0
>
(
result
);
auto
costs
=
std
::
get
<
0
>
(
result
);
auto
grads
=
std
::
get
<
1
>
(
result
).
value_or
(
undef
);
auto
grads
=
std
::
get
<
1
>
(
result
).
value_or
(
undef
);
ctx
->
save_for_backward
({
grads
});
ctx
->
save_for_backward
({
grads
});
...
@@ -48,17 +41,10 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
...
@@ -48,17 +41,10 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
const
torch
::
Tensor
&
logit_lengths
,
const
torch
::
Tensor
&
logit_lengths
,
const
torch
::
Tensor
&
target_lengths
,
const
torch
::
Tensor
&
target_lengths
,
int64_t
blank
,
int64_t
blank
,
double
clamp
,
double
clamp
)
{
bool
fused_log_softmax
=
true
)
{
at
::
AutoDispatchBelowADInplaceOrView
guard
;
at
::
AutoDispatchBelowADInplaceOrView
guard
;
auto
results
=
RNNTLossFunction
::
apply
(
auto
results
=
RNNTLossFunction
::
apply
(
logits
,
logits
,
targets
,
logit_lengths
,
target_lengths
,
blank
,
clamp
);
targets
,
logit_lengths
,
target_lengths
,
blank
,
clamp
,
fused_log_softmax
);
return
std
::
make_tuple
(
results
[
0
],
results
[
1
]);
return
std
::
make_tuple
(
results
[
0
],
results
[
1
]);
}
}
...
...
torchaudio/csrc/rnnt/compute.cpp
View file @
d74d0604
...
@@ -7,19 +7,11 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
...
@@ -7,19 +7,11 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
const
torch
::
Tensor
&
logit_lengths
,
const
torch
::
Tensor
&
logit_lengths
,
const
torch
::
Tensor
&
target_lengths
,
const
torch
::
Tensor
&
target_lengths
,
int64_t
blank
,
int64_t
blank
,
double
clamp
,
double
clamp
)
{
bool
fused_log_softmax
=
true
)
{
static
auto
op
=
torch
::
Dispatcher
::
singleton
()
static
auto
op
=
torch
::
Dispatcher
::
singleton
()
.
findSchemaOrThrow
(
"torchaudio::rnnt_loss"
,
""
)
.
findSchemaOrThrow
(
"torchaudio::rnnt_loss"
,
""
)
.
typed
<
decltype
(
rnnt_loss
)
>
();
.
typed
<
decltype
(
rnnt_loss
)
>
();
return
op
.
call
(
return
op
.
call
(
logits
,
targets
,
logit_lengths
,
target_lengths
,
blank
,
clamp
);
logits
,
targets
,
logit_lengths
,
target_lengths
,
blank
,
clamp
,
fused_log_softmax
);
}
}
TORCH_LIBRARY_FRAGMENT
(
torchaudio
,
m
)
{
TORCH_LIBRARY_FRAGMENT
(
torchaudio
,
m
)
{
...
@@ -29,6 +21,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
...
@@ -29,6 +21,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
"Tensor logit_lengths,"
"Tensor logit_lengths,"
"Tensor target_lengths,"
"Tensor target_lengths,"
"int blank,"
"int blank,"
"float clamp,"
"float clamp) -> (Tensor, Tensor?)"
);
"bool fused_log_softmax=True) -> (Tensor, Tensor?)"
);
}
}
torchaudio/csrc/rnnt/compute.h
View file @
d74d0604
...
@@ -8,5 +8,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
...
@@ -8,5 +8,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
const
torch
::
Tensor
&
logit_lengths
,
const
torch
::
Tensor
&
logit_lengths
,
const
torch
::
Tensor
&
target_lengths
,
const
torch
::
Tensor
&
target_lengths
,
int64_t
blank
,
int64_t
blank
,
double
clamp
,
double
clamp
);
bool
fused_log_softmax
);
torchaudio/csrc/rnnt/cpu/compute.cpp
View file @
d74d0604
...
@@ -12,8 +12,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
...
@@ -12,8 +12,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
const
torch
::
Tensor
&
logit_lengths
,
const
torch
::
Tensor
&
logit_lengths
,
const
torch
::
Tensor
&
target_lengths
,
const
torch
::
Tensor
&
target_lengths
,
int64_t
blank
,
int64_t
blank
,
double
clamp
,
double
clamp
)
{
bool
fused_log_softmax
=
true
)
{
TORCH_CHECK
(
TORCH_CHECK
(
logits
.
device
().
type
()
==
targets
.
device
().
type
(),
logits
.
device
().
type
()
==
targets
.
device
().
type
(),
"logits and targets must be on the same device"
);
"logits and targets must be on the same device"
);
...
@@ -81,7 +80,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
...
@@ -81,7 +80,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
options
.
numTargets_
=
logits
.
size
(
3
);
options
.
numTargets_
=
logits
.
size
(
3
);
options
.
blank_
=
blank
;
options
.
blank_
=
blank
;
options
.
clamp_
=
clamp
;
options
.
clamp_
=
clamp
;
options
.
fusedLogSmax_
=
fused_log_softmax
;
CHECK_EQ
(
logits
.
device
().
type
(),
torch
::
DeviceType
::
CPU
);
CHECK_EQ
(
logits
.
device
().
type
(),
torch
::
DeviceType
::
CPU
);
options
.
device_
=
CPU
;
options
.
device_
=
CPU
;
...
...
torchaudio/csrc/rnnt/gpu/compute.cu
View file @
d74d0604
...
@@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
...
@@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
const
torch
::
Tensor
&
logit_lengths
,
const
torch
::
Tensor
&
logit_lengths
,
const
torch
::
Tensor
&
target_lengths
,
const
torch
::
Tensor
&
target_lengths
,
int64_t
blank
,
int64_t
blank
,
double
clamp
,
double
clamp
)
{
bool
fused_log_softmax
=
true
)
{
TORCH_CHECK
(
TORCH_CHECK
(
logits
.
device
().
type
()
==
targets
.
device
().
type
(),
logits
.
device
().
type
()
==
targets
.
device
().
type
(),
"logits and targets must be on the same device"
);
"logits and targets must be on the same device"
);
...
@@ -82,7 +81,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
...
@@ -82,7 +81,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
options
.
numTargets_
=
logits
.
size
(
3
);
options
.
numTargets_
=
logits
.
size
(
3
);
options
.
blank_
=
blank
;
options
.
blank_
=
blank
;
options
.
clamp_
=
clamp
;
options
.
clamp_
=
clamp
;
options
.
fusedLogSmax_
=
fused_log_softmax
;
CHECK_EQ
(
logits
.
device
().
type
(),
torch
::
DeviceType
::
CUDA
);
CHECK_EQ
(
logits
.
device
().
type
(),
torch
::
DeviceType
::
CUDA
);
options
.
stream_
=
at
::
cuda
::
getCurrentCUDAStream
();
options
.
stream_
=
at
::
cuda
::
getCurrentCUDAStream
();
...
...
torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh
View file @
d74d0604
...
@@ -23,8 +23,7 @@ __global__ void ComputeLogProbs(
...
@@ -23,8 +23,7 @@ __global__ void ComputeLogProbs(
const
int
*
tgtLengths
,
const
int
*
tgtLengths
,
const
CAST_DTYPE
*
denominators
,
const
CAST_DTYPE
*
denominators
,
CAST_DTYPE
*
logProbs
,
CAST_DTYPE
*
logProbs
,
int
H
=
1
,
int
H
=
1
)
{
bool
fusedLogSmax
=
true
)
{
const
int
&
maxT
=
maxSrcLen
;
const
int
&
maxT
=
maxSrcLen
;
const
int
&
maxU
=
maxTgtLen
;
const
int
&
maxU
=
maxTgtLen
;
const
int
&
D
=
numTargets
;
const
int
&
D
=
numTargets
;
...
@@ -49,22 +48,12 @@ __global__ void ComputeLogProbs(
...
@@ -49,22 +48,12 @@ __global__ void ComputeLogProbs(
logProbs
[(
idx
<<
1
)
+
LOG_PROBS_SKIP_IDX
]
=
logProbs
[(
idx
<<
1
)
+
LOG_PROBS_SKIP_IDX
]
=
CAST_DTYPE
(
logits
[
idx
*
D
+
blank
])
-
denominators
[
idx
];
CAST_DTYPE
(
logits
[
idx
*
D
+
blank
])
-
denominators
[
idx
];
if
(
!
fusedLogSmax
)
{
logProbs
[(
idx
<<
1
)
+
LOG_PROBS_SKIP_IDX
]
=
CAST_DTYPE
(
logits
[
idx
*
D
+
blank
]);
}
if
(
u
<
U
-
1
)
{
if
(
u
<
U
-
1
)
{
// emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
// emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
// u).
// u).
int
target
=
targets
[
Indexer2D
(
maxU
-
1
)(
bTgt
,
u
)];
int
target
=
targets
[
Indexer2D
(
maxU
-
1
)(
bTgt
,
u
)];
logProbs
[(
idx
<<
1
)
+
LOG_PROBS_EMIT_IDX
]
=
logProbs
[(
idx
<<
1
)
+
LOG_PROBS_EMIT_IDX
]
=
CAST_DTYPE
(
logits
[
idx
*
D
+
target
])
-
denominators
[
idx
];
CAST_DTYPE
(
logits
[
idx
*
D
+
target
])
-
denominators
[
idx
];
if
(
!
fusedLogSmax
)
{
logProbs
[(
idx
<<
1
)
+
LOG_PROBS_EMIT_IDX
]
=
CAST_DTYPE
(
logits
[
idx
*
D
+
target
]);
}
}
}
}
}
...
@@ -330,8 +319,7 @@ __global__ void ComputeGradients(
...
@@ -330,8 +319,7 @@ __global__ void ComputeGradients(
const
CAST_DTYPE
*
alphas
,
const
CAST_DTYPE
*
alphas
,
const
CAST_DTYPE
*
betas
,
const
CAST_DTYPE
*
betas
,
DTYPE
*
gradients
,
DTYPE
*
gradients
,
int
H
=
1
,
int
H
=
1
)
{
bool
fusedLogSmax
=
true
)
{
const
int
bTgt
=
blockIdx
.
z
;
// 0 <= b < B
const
int
bTgt
=
blockIdx
.
z
;
// 0 <= b < B
const
int
t
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
t
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
u
=
blockIdx
.
y
;
const
int
u
=
blockIdx
.
y
;
...
@@ -353,8 +341,7 @@ __global__ void ComputeGradients(
...
@@ -353,8 +341,7 @@ __global__ void ComputeGradients(
alphas
,
alphas
,
betas
,
betas
,
gradients
,
gradients
,
H
,
H
);
fusedLogSmax
);
}
}
// This is a __global__ wrapper around ComputeAlphas
// This is a __global__ wrapper around ComputeAlphas
...
...
torchaudio/csrc/rnnt/gpu/gpu_transducer.h
View file @
d74d0604
...
@@ -102,8 +102,6 @@ status_t Compute(
...
@@ -102,8 +102,6 @@ status_t Compute(
const
int
&
blank
=
options
.
blank_
;
const
int
&
blank
=
options
.
blank_
;
const
CAST_DTYPE
clamp
=
options
.
clamp_
;
const
CAST_DTYPE
clamp
=
options
.
clamp_
;
const
bool
&
fusedLogSmax
=
options
.
fusedLogSmax_
;
{
// compute denominators.
{
// compute denominators.
status_t
status
=
LogSumExp2D
<
DTYPE
,
CAST_DTYPE
>
(
status_t
status
=
LogSumExp2D
<
DTYPE
,
CAST_DTYPE
>
(
/*stream=*/
stream
,
/*stream=*/
stream
,
...
@@ -134,8 +132,7 @@ status_t Compute(
...
@@ -134,8 +132,7 @@ status_t Compute(
/*tgtLengths=*/
tgtLengths
,
/*tgtLengths=*/
tgtLengths
,
/*denominators=*/
workspace
.
GetPointerToDenominators
(),
/*denominators=*/
workspace
.
GetPointerToDenominators
(),
/*log_probs=*/
workspace
.
GetPointerToLogProbs
(),
/*log_probs=*/
workspace
.
GetPointerToLogProbs
(),
H
,
H
);
fusedLogSmax
);
if
(
cudaGetLastError
()
!=
cudaSuccess
)
{
if
(
cudaGetLastError
()
!=
cudaSuccess
)
{
return
COMPUTE_LOG_PROBS_FAILED
;
return
COMPUTE_LOG_PROBS_FAILED
;
...
@@ -200,8 +197,7 @@ status_t Compute(
...
@@ -200,8 +197,7 @@ status_t Compute(
/*alphas=*/
workspace
.
GetPointerToAlphas
(),
/*alphas=*/
workspace
.
GetPointerToAlphas
(),
/*betas=*/
workspace
.
GetPointerToBetas
(),
/*betas=*/
workspace
.
GetPointerToBetas
(),
/*gradients=*/
gradients
,
/*gradients=*/
gradients
,
H
,
H
);
fusedLogSmax
);
if
(
cudaGetLastError
()
!=
cudaSuccess
)
{
if
(
cudaGetLastError
()
!=
cudaSuccess
)
{
return
COMPUTE_GRADIENTS_FAILED
;
return
COMPUTE_GRADIENTS_FAILED
;
}
}
...
...
torchaudio/csrc/rnnt/gpu/kernels.h
View file @
d74d0604
...
@@ -26,8 +26,7 @@ HOST_AND_DEVICE void ComputeGradientsElement(
...
@@ -26,8 +26,7 @@ HOST_AND_DEVICE void ComputeGradientsElement(
const
CAST_DTYPE
*
alphas
,
const
CAST_DTYPE
*
alphas
,
const
CAST_DTYPE
*
betas
,
const
CAST_DTYPE
*
betas
,
DTYPE
*
gradients
,
DTYPE
*
gradients
,
int
H
=
1
,
int
H
=
1
)
{
bool
fusedLogSmax
=
true
)
{
const
int
&
maxT
=
maxSrcLen
;
const
int
&
maxT
=
maxSrcLen
;
const
int
&
maxU
=
maxTgtLen
;
const
int
&
maxU
=
maxTgtLen
;
const
int
&
D
=
numTargets
;
const
int
&
D
=
numTargets
;
...
@@ -79,44 +78,22 @@ HOST_AND_DEVICE void ComputeGradientsElement(
...
@@ -79,44 +78,22 @@ HOST_AND_DEVICE void ComputeGradientsElement(
int
b_t_u_d
=
idx_b_t_u
*
D
+
d
;
int
b_t_u_d
=
idx_b_t_u
*
D
+
d
;
CAST_DTYPE
g
=
CAST_DTYPE
(
logits
[
b_t_u_d
])
+
c
;
CAST_DTYPE
g
=
CAST_DTYPE
(
logits
[
b_t_u_d
])
+
c
;
if
(
fusedLogSmax
)
{
if
(
d
==
blank
&&
t
==
T
-
1
&&
u
==
U
-
1
)
{
// last blank transition.
if
(
d
==
blank
&&
t
==
T
-
1
&&
u
==
U
-
1
)
{
// last blank transition.
gradients
[
b_t_u_d
]
=
std
::
exp
(
g
+
betas
[
idx_b_t_u
])
-
std
::
exp
(
g
);
gradients
[
b_t_u_d
]
=
std
::
exp
(
g
+
betas
[
idx_b_t_u
])
-
std
::
exp
(
g
);
}
else
if
(
t
<
T
-
1
&&
d
==
blank
)
{
}
else
if
(
t
<
T
-
1
&&
d
==
blank
)
{
gradients
[
b_t_u_d
]
=
std
::
exp
(
g
+
betas
[
idx_b_t_u
]);
gradients
[
b_t_u_d
]
=
std
::
exp
(
g
+
betas
[
idx_b_t_u
]);
if
(
idx_b_tp1_u
!=
-
1
)
{
if
(
idx_b_tp1_u
!=
-
1
)
{
gradients
[
b_t_u_d
]
=
gradients
[
b_t_u_d
]
=
gradients
[
b_t_u_d
]
-
std
::
exp
(
g
+
betas
[
idx_b_tp1_u
]);
gradients
[
b_t_u_d
]
-
std
::
exp
(
g
+
betas
[
idx_b_tp1_u
]);
}
}
else
if
(
u
<
U
-
1
&&
d
==
targets
[
idxr2
(
bTgt
,
u
)])
{
gradients
[
b_t_u_d
]
=
std
::
exp
(
g
+
betas
[
idx_b_t_u
]);
if
(
idx_b_t_up1
!=
-
1
)
{
gradients
[
b_t_u_d
]
=
gradients
[
b_t_u_d
]
-
std
::
exp
(
g
+
betas
[
idx_b_t_up1
]);
}
}
else
{
gradients
[
b_t_u_d
]
=
std
::
exp
(
g
+
betas
[
idx_b_t_u
]);
}
}
}
else
{
// Non fused log softmax case
}
else
if
(
u
<
U
-
1
&&
d
==
targets
[
idxr2
(
bTgt
,
u
)])
{
CAST_DTYPE
g
=
cost
+
CAST_DTYPE
(
logits
[
b_t_u_d
]);
gradients
[
b_t_u_d
]
=
std
::
exp
(
g
+
betas
[
idx_b_t_u
]);
if
(
d
==
blank
&&
t
==
T
-
1
&&
u
==
U
-
1
)
{
if
(
idx_b_t_up1
!=
-
1
)
{
gradients
[
b_t_u_d
]
=
g
+
alphas
[
idx_b_t_u
];
gradients
[
b_t_u_d
]
=
}
else
if
(
t
<
T
-
1
&&
d
==
blank
)
{
gradients
[
b_t_u_d
]
-
std
::
exp
(
g
+
betas
[
idx_b_t_up1
]);
if
(
idx_b_tp1_u
!=
-
1
)
{
gradients
[
b_t_u_d
]
=
g
+
alphas
[
idx_b_t_u
]
+
betas
[
idx_b_tp1_u
];
}
else
{
gradients
[
b_t_u_d
]
=
g
+
CAST_DTYPE
(
-
INFINITY
);
}
}
else
if
(
u
<
U
-
1
&&
d
==
targets
[
idxr2
(
bTgt
,
u
)])
{
if
(
idx_b_t_up1
!=
-
1
)
{
gradients
[
b_t_u_d
]
=
g
+
alphas
[
idx_b_t_u
]
+
betas
[
idx_b_t_up1
];
}
else
{
gradients
[
b_t_u_d
]
=
g
+
CAST_DTYPE
(
-
INFINITY
);
}
}
else
{
gradients
[
b_t_u_d
]
=
g
+
CAST_DTYPE
(
-
INFINITY
);
}
}
gradients
[
b_t_u_d
]
=
-
std
::
exp
(
gradients
[
b_t_u_d
]);
}
else
{
gradients
[
b_t_u_d
]
=
std
::
exp
(
g
+
betas
[
idx_b_t_u
]);
}
}
if
(
clamp
>
0
)
{
if
(
clamp
>
0
)
{
...
...
torchaudio/csrc/rnnt/options.h
View file @
d74d0604
...
@@ -42,12 +42,6 @@ typedef struct Options {
...
@@ -42,12 +42,6 @@ typedef struct Options {
// num_targets = D.
// num_targets = D.
int
numTargets_
;
int
numTargets_
;
// if set to true, inputs are logits and gradients are
// fused with logsoftmax gradients.
// if set to false, log_softmax is computed outside of loss
// True by default
bool
fusedLogSmax_
;
Options
()
Options
()
:
device_
(
UNDEFINED
),
:
device_
(
UNDEFINED
),
numThreads_
(
0
),
numThreads_
(
0
),
...
@@ -58,8 +52,7 @@ typedef struct Options {
...
@@ -58,8 +52,7 @@ typedef struct Options {
nHypos_
(
1
),
nHypos_
(
1
),
maxSrcLen_
(
0
),
maxSrcLen_
(
0
),
maxTgtLen_
(
0
),
maxTgtLen_
(
0
),
numTargets_
(
0
),
numTargets_
(
0
)
{}
fusedLogSmax_
(
true
)
{}
int
BU
()
const
{
int
BU
()
const
{
return
batchSize_
*
maxTgtLen_
*
nHypos_
;
return
batchSize_
*
maxTgtLen_
*
nHypos_
;
...
...
torchaudio/prototype/rnnt_loss.py
View file @
d74d0604
...
@@ -14,7 +14,6 @@ def rnnt_loss(
...
@@ -14,7 +14,6 @@ def rnnt_loss(
target_lengths
:
Tensor
,
target_lengths
:
Tensor
,
blank
:
int
=
-
1
,
blank
:
int
=
-
1
,
clamp
:
float
=
-
1
,
clamp
:
float
=
-
1
,
fused_log_softmax
:
bool
=
True
,
reduction
:
str
=
"mean"
,
reduction
:
str
=
"mean"
,
):
):
"""Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
"""Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
...
@@ -31,7 +30,6 @@ def rnnt_loss(
...
@@ -31,7 +30,6 @@ def rnnt_loss(
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
blank (int, opt): blank label (Default: ``-1``)
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
reduction (string, optional): Specifies the reduction to apply to the output:
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
...
@@ -42,9 +40,6 @@ def rnnt_loss(
...
@@ -42,9 +40,6 @@ def rnnt_loss(
if
reduction
not
in
[
'none'
,
'mean'
,
'sum'
]:
if
reduction
not
in
[
'none'
,
'mean'
,
'sum'
]:
raise
ValueError
(
"reduction should be one of 'none', 'mean', or 'sum'"
)
raise
ValueError
(
"reduction should be one of 'none', 'mean', or 'sum'"
)
if
not
fused_log_softmax
:
logits
=
torch
.
nn
.
functional
.
log_softmax
(
logits
,
dim
=-
1
)
if
blank
<
0
:
# reinterpret blank index if blank < 0.
if
blank
<
0
:
# reinterpret blank index if blank < 0.
blank
=
logits
.
shape
[
-
1
]
+
blank
blank
=
logits
.
shape
[
-
1
]
+
blank
...
@@ -55,7 +50,6 @@ def rnnt_loss(
...
@@ -55,7 +50,6 @@ def rnnt_loss(
target_lengths
=
target_lengths
,
target_lengths
=
target_lengths
,
blank
=
blank
,
blank
=
blank
,
clamp
=
clamp
,
clamp
=
clamp
,
fused_log_softmax
=
fused_log_softmax
)
)
if
reduction
==
'mean'
:
if
reduction
==
'mean'
:
...
@@ -77,7 +71,6 @@ class RNNTLoss(torch.nn.Module):
...
@@ -77,7 +71,6 @@ class RNNTLoss(torch.nn.Module):
Args:
Args:
blank (int, opt): blank label (Default: ``-1``)
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
reduction (string, optional): Specifies the reduction to apply to the output:
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
"""
"""
...
@@ -86,13 +79,11 @@ class RNNTLoss(torch.nn.Module):
...
@@ -86,13 +79,11 @@ class RNNTLoss(torch.nn.Module):
self
,
self
,
blank
:
int
=
-
1
,
blank
:
int
=
-
1
,
clamp
:
float
=
-
1.
,
clamp
:
float
=
-
1.
,
fused_log_softmax
:
bool
=
True
,
reduction
:
str
=
"mean"
,
reduction
:
str
=
"mean"
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
blank
=
blank
self
.
blank
=
blank
self
.
clamp
=
clamp
self
.
clamp
=
clamp
self
.
fused_log_softmax
=
fused_log_softmax
self
.
reduction
=
reduction
self
.
reduction
=
reduction
def
forward
(
def
forward
(
...
@@ -120,6 +111,5 @@ class RNNTLoss(torch.nn.Module):
...
@@ -120,6 +111,5 @@ class RNNTLoss(torch.nn.Module):
target_lengths
,
target_lengths
,
self
.
blank
,
self
.
blank
,
self
.
clamp
,
self
.
clamp
,
self
.
fused_log_softmax
,
self
.
reduction
self
.
reduction
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment