OpenDAS / Torchaudio · Commits

Commit 2376e9c9 (unverified)
Rename RNNT loss C++ parameters (#1602)

Authored Jun 24, 2021 by Caroline Chen; committed via GitHub on Jun 24, 2021.
Parent: 6a8ecd98
Changes: 12 changed files, with 109 additions and 103 deletions (+109 / -103).

- torchaudio/csrc/rnnt/autograd.cpp (+12, -12)
- torchaudio/csrc/rnnt/compute.cpp (+9, -9)
- torchaudio/csrc/rnnt/compute.h (+3, -3)
- torchaudio/csrc/rnnt/compute_alphas.cpp (+2, -2)
- torchaudio/csrc/rnnt/compute_betas.cpp (+2, -2)
- torchaudio/csrc/rnnt/cpu/compute.cpp (+26, -23)
- torchaudio/csrc/rnnt/cpu/compute_alphas.cpp (+6, -6)
- torchaudio/csrc/rnnt/cpu/compute_betas.cpp (+7, -7)
- torchaudio/csrc/rnnt/gpu/compute.cu (+26, -23)
- torchaudio/csrc/rnnt/gpu/compute_alphas.cu (+6, -6)
- torchaudio/csrc/rnnt/gpu/compute_betas.cu (+7, -7)
- torchaudio/prototype/rnnt_loss.py (+3, -3)
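Every hunk below is the same mechanical rename: the C++ parameters src_lengths, tgt_lengths, and fused_log_smax become logit_lengths, target_lengths, and fused_log_softmax, matching the names the Python prototype already exposes. As a hedged sketch of what this means for a caller of the registered op (keyword names are taken from the schema strings in this diff; the tensors are placeholders):

    # Before this commit, the op's schema used the short names:
    #   torch.ops.torchaudio.rnnt_loss(..., src_lengths=..., tgt_lengths=...,
    #                                  fused_log_smax=True, ...)
    # After this commit, the same call is spelled with the renamed keywords:
    costs, gradients = torch.ops.torchaudio.rnnt_loss(
        logits=logits,
        targets=targets,
        logit_lengths=logit_lengths,
        target_lengths=target_lengths,
        blank=blank,
        clamp=clamp,
        fused_log_softmax=True,
        reuse_logits_for_grads=True,
    )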
torchaudio/csrc/rnnt/autograd.cpp

@@ -10,21 +10,21 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
       torch::autograd::AutogradContext* ctx,
       torch::Tensor& logits,
       const torch::Tensor& targets,
-      const torch::Tensor& src_lengths,
-      const torch::Tensor& tgt_lengths,
+      const torch::Tensor& logit_lengths,
+      const torch::Tensor& target_lengths,
       int64_t blank,
       double clamp,
-      bool fused_log_smax = true,
+      bool fused_log_softmax = true,
       bool reuse_logits_for_grads = true) {
     torch::Tensor undef;
     auto result = rnnt_loss(
         logits,
         targets,
-        src_lengths,
-        tgt_lengths,
+        logit_lengths,
+        target_lengths,
         blank,
         clamp,
-        fused_log_smax,
+        fused_log_softmax,
         reuse_logits_for_grads);
     auto costs = std::get<0>(result);
     auto grads = std::get<1>(result).value_or(undef);
@@ -47,21 +47,21 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
 std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
     torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_smax = true,
+    bool fused_log_softmax = true,
     bool reuse_logits_for_grads = true) {
   at::AutoDispatchBelowADInplaceOrView guard;
   auto results = RNNTLossFunction::apply(
       logits,
       targets,
-      src_lengths,
-      tgt_lengths,
+      logit_lengths,
+      target_lengths,
       blank,
       clamp,
-      fused_log_smax,
+      fused_log_softmax,
       reuse_logits_for_grads);
   return std::make_tuple(results[0], results[1]);
 }
torchaudio/csrc/rnnt/compute.cpp

@@ -4,11 +4,11 @@
 std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_smax = true,
+    bool fused_log_softmax = true,
     bool reuse_logits_for_grads = true) {
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torchaudio::rnnt_loss", "")
@@ -16,11 +16,11 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
   return op.call(
       logits,
       targets,
-      src_lengths,
-      tgt_lengths,
+      logit_lengths,
+      target_lengths,
       blank,
       clamp,
-      fused_log_smax,
+      fused_log_softmax,
       reuse_logits_for_grads);
 }
@@ -28,10 +28,10 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
       "rnnt_loss(Tensor logits,"
       "Tensor targets,"
-      "Tensor src_lengths,"
-      "Tensor tgt_lengths,"
+      "Tensor logit_lengths,"
+      "Tensor target_lengths,"
       "int blank,"
       "float clamp,"
-      "bool fused_log_smax=True,"
+      "bool fused_log_softmax=True,"
       "bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)");
 }
torchaudio/csrc/rnnt/compute.h

@@ -5,9 +5,9 @@
 std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_smax,
+    bool fused_log_softmax,
     bool reuse_logits_for_grads);
torchaudio/csrc/rnnt/compute_alphas.cpp

@@ -4,8 +4,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
       "rnnt_loss_alphas(Tensor logits,"
       "Tensor targets,"
-      "Tensor src_lengths,"
-      "Tensor tgt_lengths,"
+      "Tensor logit_lengths,"
+      "Tensor target_lengths,"
      "int blank,"
      "float clamp) -> Tensor");
 }
torchaudio/csrc/rnnt/compute_betas.cpp

@@ -4,8 +4,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
       "rnnt_loss_betas(Tensor logits,"
       "Tensor targets,"
-      "Tensor src_lengths,"
-      "Tensor tgt_lengths,"
+      "Tensor logit_lengths,"
+      "Tensor target_lengths,"
      "int blank,"
      "float clamp) -> Tensor");
 }
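The rnnt_loss_alphas and rnnt_loss_betas fragments register the companion ops that expose the forward (alpha) and backward (beta) dynamic-programming variables under the same renamed parameters. A sketch of calling them positionally, in schema order (the tensors are assumed to follow the layout validated in cpu/compute.cpp below):

    # Schema order: logits, targets, logit_lengths, target_lengths, blank, clamp
    alphas = torch.ops.torchaudio.rnnt_loss_alphas(
        logits, targets, logit_lengths, target_lengths, blank, clamp)
    betas = torch.ops.torchaudio.rnnt_loss_betas(
        logits, targets, logit_lengths, target_lengths, blank, clamp)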
torchaudio/csrc/rnnt/cpu/compute.cpp

@@ -9,20 +9,20 @@ namespace cpu {
 std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_smax = true,
+    bool fused_log_softmax = true,
     bool reuse_logits_for_grads = true) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
   TORCH_CHECK(
-      logits.device().type() == src_lengths.device().type(),
+      logits.device().type() == logit_lengths.device().type(),
       "logits and logit_lengths must be on the same device");
   TORCH_CHECK(
-      logits.device().type() == tgt_lengths.device().type(),
+      logits.device().type() == target_lengths.device().type(),
       "logits and target_lengths must be on the same device");
   TORCH_CHECK(
@@ -30,28 +30,31 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       "logits must be float32 or float16 (half) type");
   TORCH_CHECK(
       targets.dtype() == torch::kInt32, "targets must be int32 type");
   TORCH_CHECK(
-      src_lengths.dtype() == torch::kInt32, "logit_lengths must be int32 type");
+      logit_lengths.dtype() == torch::kInt32,
+      "logit_lengths must be int32 type");
   TORCH_CHECK(
-      tgt_lengths.dtype() == torch::kInt32,
+      target_lengths.dtype() == torch::kInt32,
       "target_lengths must be int32 type");

   TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
   TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
-  TORCH_CHECK(src_lengths.is_contiguous(), "logit_lengths must be contiguous");
-  TORCH_CHECK(tgt_lengths.is_contiguous(), "target_lengths must be contiguous");
+  TORCH_CHECK(
+      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
+  TORCH_CHECK(
+      target_lengths.is_contiguous(), "target_lengths must be contiguous");

   TORCH_CHECK(
       logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
   TORCH_CHECK(
       targets.dim() == 2, "targets must be 2-D (batch, max target length)");
-  TORCH_CHECK(src_lengths.dim() == 1, "logit_lengths must be 1-D");
-  TORCH_CHECK(tgt_lengths.dim() == 1, "target_lengths must be 1-D");
+  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
+  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");

   TORCH_CHECK(
-      src_lengths.size(0) == logits.size(0),
+      logit_lengths.size(0) == logits.size(0),
       "batch dimension mismatch between logits and logit_lengths");
   TORCH_CHECK(
-      tgt_lengths.size(0) == logits.size(0),
+      target_lengths.size(0) == logits.size(0),
       "batch dimension mismatch between logits and target_lengths");
   TORCH_CHECK(
       targets.size(0) == logits.size(0),
@@ -62,24 +65,24 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       "blank must be within [0, logits.shape[-1])");

   TORCH_CHECK(
-      logits.size(1) == at::max(src_lengths).item().toInt(),
+      logits.size(1) == at::max(logit_lengths).item().toInt(),
       "input length mismatch");
   TORCH_CHECK(
-      logits.size(2) == at::max(tgt_lengths).item().toInt() + 1,
+      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
       "output length mismatch");
   TORCH_CHECK(
-      targets.size(1) == at::max(tgt_lengths).item().toInt(),
+      targets.size(1) == at::max(target_lengths).item().toInt(),
       "target length mismatch");

   Options options;
-  options.batchSize_ = src_lengths.size(0);
-  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
+  options.batchSize_ = logit_lengths.size(0);
+  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
   options.maxSrcLen_ = logits.size(1);
   options.maxTgtLen_ = logits.size(2);
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_smax;
+  options.fusedLogSmax_ = fused_log_softmax;

   CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
   options.device_ = CPU;
@@ -121,8 +124,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       /*workspace=*/workspace,
       /*logits=*/logits.data_ptr<float>(),
       /*targets=*/targets.data_ptr<int>(),
-      /*src_lengths=*/src_lengths.data_ptr<int>(),
-      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+      /*target_lengths=*/target_lengths.data_ptr<int>(),
       /*costs=*/costs.data_ptr<float>(),
       /*gradients=*/
       (gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
@@ -133,8 +136,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       /*workspace=*/workspace,
       /*logits=*/logits.data_ptr<c10::Half>(),
       /*targets=*/targets.data_ptr<int>(),
-      /*src_lengths=*/src_lengths.data_ptr<int>(),
-      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+      /*target_lengths=*/target_lengths.data_ptr<int>(),
       /*costs=*/costs.data_ptr<c10::Half>(),
       /*gradients=*/
       (gradients == c10::nullopt) ? nullptr
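Read together, the TORCH_CHECK calls above spell out the expected input layout for the op. A hedged Python sketch that builds tensors satisfying every check (the dtypes and shape relations come from the checks themselves; the concrete sizes are invented for illustration):

    import torch

    batch, max_src, max_tgt, num_classes = 2, 10, 5, 20  # hypothetical sizes
    # logits: 4-D (batch, time, target, class), float32 or float16, contiguous
    logits = torch.rand(batch, max_src, max_tgt + 1, num_classes)
    # targets: 2-D (batch, max target length), int32
    targets = torch.randint(1, num_classes, (batch, max_tgt), dtype=torch.int32)
    # lengths: 1-D int32, one entry per batch element;
    # logits.size(1) must equal max(logit_lengths), and
    # logits.size(2) must equal max(target_lengths) + 1
    logit_lengths = torch.tensor([max_src, max_src - 2], dtype=torch.int32)
    target_lengths = torch.tensor([max_tgt, max_tgt - 1], dtype=torch.int32)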
torchaudio/csrc/rnnt/cpu/compute_alphas.cpp

@@ -8,13 +8,13 @@ namespace cpu {
 torch::Tensor compute_alphas(
     const torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp) {
   Options options;
-  options.batchSize_ = src_lengths.size(0);
-  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
+  options.batchSize_ = logit_lengths.size(0);
+  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
   options.maxSrcLen_ = logits.size(1);
   options.maxTgtLen_ = logits.size(2);
   options.numTargets_ = logits.size(3);
@@ -55,8 +55,8 @@ torch::Tensor compute_alphas(
       /*workspace=*/workspace,
       /*logits=*/logits.data_ptr<float>(),
       /*targets=*/targets.data_ptr<int>(),
-      /*src_lengths=*/src_lengths.data_ptr<int>(),
-      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+      /*target_lengths=*/target_lengths.data_ptr<int>(),
       /*alphas=*/alphas.data_ptr<float>());
   return alphas;
 }
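One detail worth noting in the options setup: nHypos_ is derived as target_lengths.size(0) / logit_lengths.size(0), so the op supports several target hypotheses sharing one set of logits per utterance. A tiny numeric illustration (the sizes are made up):

    # 2 utterances (logit_lengths has 2 entries) and 6 target-length entries
    # means 6 / 2 = 3 hypotheses per utterance; in the common one-hypothesis
    # case both tensors have the same length and nHypos_ is 1.
    n_hypos = 6 // 2
    assert n_hypos == 3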
torchaudio/csrc/rnnt/cpu/compute_betas.cpp

@@ -8,13 +8,13 @@ namespace cpu {
 torch::Tensor compute_betas(
     const torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp) {
   Options options;
-  options.batchSize_ = src_lengths.size(0);
-  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
+  options.batchSize_ = logit_lengths.size(0);
+  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
   options.maxSrcLen_ = logits.size(1);
   options.maxTgtLen_ = logits.size(2);
   options.numTargets_ = logits.size(3);
@@ -25,7 +25,7 @@ torch::Tensor compute_betas(
   options.device_ = CPU;

   torch::Tensor costs = torch::empty(
-      tgt_lengths.size(0),
+      target_lengths.size(0),
       torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
   torch::Tensor betas = torch::zeros(
@@ -59,8 +59,8 @@ torch::Tensor compute_betas(
       /*workspace=*/workspace,
       /*logits=*/logits.data_ptr<float>(),
       /*targets=*/targets.data_ptr<int>(),
-      /*src_lengths=*/src_lengths.data_ptr<int>(),
-      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+      /*target_lengths=*/target_lengths.data_ptr<int>(),
       /*costs=*/costs.data_ptr<float>(),
       /*betas=*/betas.data_ptr<float>());
   return betas;
torchaudio/csrc/rnnt/gpu/compute.cu

@@ -10,20 +10,20 @@ namespace gpu {
 std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_smax = true,
+    bool fused_log_softmax = true,
     bool reuse_logits_for_grads = true) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
   TORCH_CHECK(
-      logits.device().type() == src_lengths.device().type(),
+      logits.device().type() == logit_lengths.device().type(),
       "logits and logit_lengths must be on the same device");
   TORCH_CHECK(
-      logits.device().type() == tgt_lengths.device().type(),
+      logits.device().type() == target_lengths.device().type(),
       "logits and target_lengths must be on the same device");
   TORCH_CHECK(
@@ -31,28 +31,31 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       "logits must be float32 or float16 (half) type");
   TORCH_CHECK(
       targets.dtype() == torch::kInt32, "targets must be int32 type");
   TORCH_CHECK(
-      src_lengths.dtype() == torch::kInt32, "logit_lengths must be int32 type");
+      logit_lengths.dtype() == torch::kInt32,
+      "logit_lengths must be int32 type");
   TORCH_CHECK(
-      tgt_lengths.dtype() == torch::kInt32,
+      target_lengths.dtype() == torch::kInt32,
       "target_lengths must be int32 type");

   TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
   TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
-  TORCH_CHECK(src_lengths.is_contiguous(), "logit_lengths must be contiguous");
-  TORCH_CHECK(tgt_lengths.is_contiguous(), "target_lengths must be contiguous");
+  TORCH_CHECK(
+      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
+  TORCH_CHECK(
+      target_lengths.is_contiguous(), "target_lengths must be contiguous");

   TORCH_CHECK(
       logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
   TORCH_CHECK(
       targets.dim() == 2, "targets must be 2-D (batch, max target length)");
-  TORCH_CHECK(src_lengths.dim() == 1, "logit_lengths must be 1-D");
-  TORCH_CHECK(tgt_lengths.dim() == 1, "target_lengths must be 1-D");
+  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
+  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");

   TORCH_CHECK(
-      src_lengths.size(0) == logits.size(0),
+      logit_lengths.size(0) == logits.size(0),
       "batch dimension mismatch between logits and logit_lengths");
   TORCH_CHECK(
-      tgt_lengths.size(0) == logits.size(0),
+      target_lengths.size(0) == logits.size(0),
       "batch dimension mismatch between logits and target_lengths");
   TORCH_CHECK(
       targets.size(0) == logits.size(0),
@@ -63,24 +66,24 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       "blank must be within [0, logits.shape[-1])");

   TORCH_CHECK(
-      logits.size(1) == at::max(src_lengths).item().toInt(),
+      logits.size(1) == at::max(logit_lengths).item().toInt(),
       "input length mismatch");
   TORCH_CHECK(
-      logits.size(2) == at::max(tgt_lengths).item().toInt() + 1,
+      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
       "output length mismatch");
   TORCH_CHECK(
-      targets.size(1) == at::max(tgt_lengths).item().toInt(),
+      targets.size(1) == at::max(target_lengths).item().toInt(),
       "target length mismatch");

   Options options;
-  options.batchSize_ = src_lengths.size(0);
-  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
+  options.batchSize_ = logit_lengths.size(0);
+  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
   options.maxSrcLen_ = logits.size(1);
   options.maxTgtLen_ = logits.size(2);
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_smax;
+  options.fusedLogSmax_ = fused_log_softmax;

   CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
   options.stream_ = at::cuda::getCurrentCUDAStream();
@@ -124,8 +127,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       /*workspace=*/workspace,
       /*logits=*/logits.data_ptr<float>(),
       /*targets=*/targets.data_ptr<int>(),
-      /*src_lengths=*/src_lengths.data_ptr<int>(),
-      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+      /*target_lengths=*/target_lengths.data_ptr<int>(),
       /*costs=*/costs.data_ptr<float>(),
       /*gradients=*/
       (gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
@@ -136,8 +139,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       /*workspace=*/workspace,
       /*logits=*/logits.data_ptr<c10::Half>(),
       /*targets=*/targets.data_ptr<int>(),
-      /*src_lengths=*/src_lengths.data_ptr<int>(),
-      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+      /*target_lengths=*/target_lengths.data_ptr<int>(),
       /*costs=*/costs.data_ptr<c10::Half>(),
       /*gradients=*/
       (gradients == c10::nullopt) ? nullptr
torchaudio/csrc/rnnt/gpu/compute_alphas.cu

@@ -9,13 +9,13 @@ namespace gpu {
 torch::Tensor compute_alphas(
     const torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp) {
   Options options;
-  options.batchSize_ = src_lengths.size(0);
-  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
+  options.batchSize_ = logit_lengths.size(0);
+  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
   options.maxSrcLen_ = logits.size(1);
   options.maxTgtLen_ = logits.size(2);
   options.numTargets_ = logits.size(3);
@@ -58,8 +58,8 @@ torch::Tensor compute_alphas(
       /*workspace=*/workspace,
       /*logits=*/logits.data_ptr<float>(),
       /*targets=*/targets.data_ptr<int>(),
-      /*src_lengths=*/src_lengths.data_ptr<int>(),
-      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+      /*target_lengths=*/target_lengths.data_ptr<int>(),
       /*alphas=*/alphas.data_ptr<float>());
   return alphas;
 }
...
torchaudio/csrc/rnnt/gpu/compute_betas.cu
View file @
2376e9c9
...
@@ -9,13 +9,13 @@ namespace gpu {
...
@@ -9,13 +9,13 @@ namespace gpu {
torch
::
Tensor
compute_betas
(
torch
::
Tensor
compute_betas
(
const
torch
::
Tensor
&
logits
,
const
torch
::
Tensor
&
logits
,
const
torch
::
Tensor
&
targets
,
const
torch
::
Tensor
&
targets
,
const
torch
::
Tensor
&
src
_lengths
,
const
torch
::
Tensor
&
logit
_lengths
,
const
torch
::
Tensor
&
t
g
t_lengths
,
const
torch
::
Tensor
&
t
arge
t_lengths
,
int64_t
blank
,
int64_t
blank
,
double
clamp
)
{
double
clamp
)
{
Options
options
;
Options
options
;
options
.
batchSize_
=
src
_lengths
.
size
(
0
);
options
.
batchSize_
=
logit
_lengths
.
size
(
0
);
options
.
nHypos_
=
t
g
t_lengths
.
size
(
0
)
/
src
_lengths
.
size
(
0
);
options
.
nHypos_
=
t
arge
t_lengths
.
size
(
0
)
/
logit
_lengths
.
size
(
0
);
options
.
maxSrcLen_
=
logits
.
size
(
1
);
options
.
maxSrcLen_
=
logits
.
size
(
1
);
options
.
maxTgtLen_
=
logits
.
size
(
2
);
options
.
maxTgtLen_
=
logits
.
size
(
2
);
options
.
numTargets_
=
logits
.
size
(
3
);
options
.
numTargets_
=
logits
.
size
(
3
);
...
@@ -28,7 +28,7 @@ torch::Tensor compute_betas(
...
@@ -28,7 +28,7 @@ torch::Tensor compute_betas(
options
.
device_
=
GPU
;
options
.
device_
=
GPU
;
torch
::
Tensor
costs
=
torch
::
empty
(
torch
::
Tensor
costs
=
torch
::
empty
(
t
g
t_lengths
.
size
(
0
),
t
arge
t_lengths
.
size
(
0
),
torch
::
TensorOptions
().
device
(
logits
.
device
()).
dtype
(
logits
.
dtype
()));
torch
::
TensorOptions
().
device
(
logits
.
device
()).
dtype
(
logits
.
dtype
()));
torch
::
Tensor
betas
=
torch
::
zeros
(
torch
::
Tensor
betas
=
torch
::
zeros
(
...
@@ -62,8 +62,8 @@ torch::Tensor compute_betas(
...
@@ -62,8 +62,8 @@ torch::Tensor compute_betas(
/*workspace=*/
workspace
,
/*workspace=*/
workspace
,
/*logits=*/
logits
.
data_ptr
<
float
>
(),
/*logits=*/
logits
.
data_ptr
<
float
>
(),
/*targets=*/
targets
.
data_ptr
<
int
>
(),
/*targets=*/
targets
.
data_ptr
<
int
>
(),
/*
src
_lengths=*/
src
_lengths
.
data_ptr
<
int
>
(),
/*
logit
_lengths=*/
logit
_lengths
.
data_ptr
<
int
>
(),
/*t
g
t_lengths=*/
t
g
t_lengths
.
data_ptr
<
int
>
(),
/*t
arge
t_lengths=*/
t
arge
t_lengths
.
data_ptr
<
int
>
(),
/*costs=*/
costs
.
data_ptr
<
float
>
(),
/*costs=*/
costs
.
data_ptr
<
float
>
(),
/*betas=*/
betas
.
data_ptr
<
float
>
());
/*betas=*/
betas
.
data_ptr
<
float
>
());
return
betas
;
return
betas
;
...
...
torchaudio/prototype/rnnt_loss.py

@@ -51,11 +51,11 @@ def rnnt_loss(
     costs, gradients = torch.ops.torchaudio.rnnt_loss(
         logits=logits,
         targets=targets,
-        src_lengths=logit_lengths,
-        tgt_lengths=target_lengths,
+        logit_lengths=logit_lengths,
+        target_lengths=target_lengths,
         blank=blank,
         clamp=clamp,
-        fused_log_smax=fused_log_softmax,
+        fused_log_softmax=fused_log_softmax,
         reuse_logits_for_grads=reuse_logits_for_grads,)

     return costs
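With the Python wrapper now forwarding every argument under its public name, an end-to-end call looks like the following sketch (it assumes the prototype module from this tree and tensors shaped as in the validation example above; the blank and clamp values are illustrative, not defaults confirmed by this diff):

    from torchaudio.prototype.rnnt_loss import rnnt_loss

    costs = rnnt_loss(
        logits=logits,                  # (batch, time, target, class), float32
        targets=targets,                # (batch, max target length), int32
        logit_lengths=logit_lengths,    # (batch,), int32
        target_lengths=target_lengths,  # (batch,), int32
        blank=0,                        # index of the blank label (illustrative)
        clamp=-1.0,                     # illustrative gradient-clamp value
        fused_log_softmax=True,
        reuse_logits_for_grads=True,
    )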