Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
a21b08e3
"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6184d8a43357f3397c0848b5d0b716cf389d1f30"
Unverified
Commit
a21b08e3
authored
May 19, 2021
by
Caroline Chen
Committed by
GitHub
May 19, 2021
Browse files
Remove unused RNNTL functions (#1518)
parent
af7eb4d6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
0 additions
and
167 deletions
+0
-167
torchaudio/csrc/rnnt/transducer.h
torchaudio/csrc/rnnt/transducer.h
+0
-121
torchaudio/prototype/rnnt_loss.py
torchaudio/prototype/rnnt_loss.py
+0
-46
No files found.
torchaudio/csrc/rnnt/transducer.h
deleted
100644 → 0
View file @
af7eb4d6
#pragma once
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
namespace
torchaudio
{
namespace
rnnt
{
template
<
typename
DTYPE
,
typename
CAST_DTYPE
>
status_t
Compute
(
const
Workspace
<
CAST_DTYPE
>&
workspace
,
const
DTYPE
*
logits
,
const
int
*
targets
,
const
int
*
srcLengths
,
const
int
*
tgtLengths
,
DTYPE
*
costs
,
DTYPE
*
gradients
=
nullptr
)
{
switch
(
workspace
.
GetOptions
().
device_
)
{
case
CPU
:
{
status_t
status
=
cpu
::
Compute
<
DTYPE
,
CAST_DTYPE
>
(
/*workspace=*/
workspace
,
/*logits=*/
logits
,
/*targets=*/
targets
,
/*srcLengths=*/
srcLengths
,
/*tgtLengths=*/
tgtLengths
,
/*costs=*/
costs
,
/*gradients=*/
gradients
);
return
status
;
}
case
GPU
:
{
status_t
status
=
gpu
::
Compute
<
DTYPE
,
CAST_DTYPE
>
(
/*workspace=*/
workspace
,
/*logits=*/
logits
,
/*targets=*/
targets
,
/*srcLengths=*/
srcLengths
,
/*tgtLengths=*/
tgtLengths
,
/*costs=*/
costs
,
/*gradients=*/
gradients
);
return
status
;
}
default:
{
return
FAILURE
;
}
};
}
template
<
typename
DTYPE
,
typename
CAST_DTYPE
>
status_t
ComputeAlphas
(
const
Workspace
<
CAST_DTYPE
>&
workspace
,
const
DTYPE
*
logits
,
const
int
*
targets
,
const
int
*
srcLengths
,
const
int
*
tgtLengths
,
DTYPE
*
alphas
)
{
switch
(
workspace
.
GetOptions
().
device_
)
{
case
CPU
:
{
status_t
status
=
cpu
::
ComputeAlphas
<
DTYPE
,
CAST_DTYPE
>
(
/*workspace=*/
workspace
,
/*logits=*/
logits
,
/*targets=*/
targets
,
/*srcLengths=*/
srcLengths
,
/*tgtLengths=*/
tgtLengths
,
/*alphas=*/
alphas
);
return
status
;
}
case
GPU
:
{
status_t
status
=
gpu
::
ComputeAlphas
<
DTYPE
,
CAST_DTYPE
>
(
/*workspace=*/
workspace
,
/*logits=*/
logits
,
/*targets=*/
targets
,
/*srcLengths=*/
srcLengths
,
/*tgtLengths=*/
tgtLengths
,
/*costs=*/
alphas
);
return
status
;
}
default:
{
return
FAILURE
;
}
};
}
template
<
typename
DTYPE
,
typename
CAST_DTYPE
>
status_t
ComputeBetas
(
const
Workspace
<
CAST_DTYPE
>&
workspace
,
const
DTYPE
*
logits
,
const
int
*
targets
,
const
int
*
srcLengths
,
const
int
*
tgtLengths
,
DTYPE
*
costs
,
DTYPE
*
betas
)
{
switch
(
workspace
.
GetOptions
().
device_
)
{
case
CPU
:
{
status_t
status
=
cpu
::
ComputeBetas
<
DTYPE
,
CAST_DTYPE
>
(
/*workspace=*/
workspace
,
/*logits=*/
logits
,
/*targets=*/
targets
,
/*srcLengths=*/
srcLengths
,
/*tgtLengths=*/
tgtLengths
,
/*costs=*/
costs
,
/*betas=*/
betas
);
return
status
;
}
case
GPU
:
{
status_t
status
=
gpu
::
ComputeBetas
<
DTYPE
,
CAST_DTYPE
>
(
/*workspace=*/
workspace
,
/*logits=*/
logits
,
/*targets=*/
targets
,
/*srcLengths=*/
srcLengths
,
/*tgtLengths=*/
tgtLengths
,
/*costs=*/
costs
,
/*betas=*/
betas
);
return
status
;
}
default:
{
return
FAILURE
;
}
};
}
}
// namespace rnnt
}
// namespace torchaudio
torchaudio/prototype/rnnt_loss.py
View file @
a21b08e3
...
...
@@ -7,52 +7,6 @@ __all__ = [
]
def
_rnnt_loss_alphas
(
logits
,
targets
,
logit_lengths
,
target_lengths
,
blank
=-
1
,
clamp
=-
1
,
):
"""
Compute alphas for RNN transducer loss.
See documentation for RNNTLoss
"""
return
torch
.
ops
.
torchaudio
.
rnnt_loss_alphas
(
logits
,
targets
,
logit_lengths
,
target_lengths
,
blank
,
clamp
,
)
def
_rnnt_loss_betas
(
logits
,
targets
,
logit_lengths
,
target_lengths
,
blank
=-
1
,
clamp
=-
1
,
):
"""
Compute betas for RNN transducer loss
See documentation for RNNTLoss
"""
return
torch
.
ops
.
torchaudio
.
rnnt_loss_betas
(
logits
,
targets
,
logit_lengths
,
target_lengths
,
blank
,
clamp
,
)
def
rnnt_loss
(
logits
:
Tensor
,
targets
:
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment