Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
d49e6e45
Unverified
Commit
d49e6e45
authored
Jul 27, 2021
by
moto
Committed by
GitHub
Jul 27, 2021
Browse files
Replace simple_ctc with Python greedy decoder (#1558)
parent
1b52e720
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
38 additions
and
57 deletions
+38
-57
.gitmodules
.gitmodules
+0
-3
examples/libtorchaudio/CMakeLists.txt
examples/libtorchaudio/CMakeLists.txt
+0
-1
examples/libtorchaudio/simplectc
examples/libtorchaudio/simplectc
+0
-1
examples/libtorchaudio/speech_recognition/CMakeLists.txt
examples/libtorchaudio/speech_recognition/CMakeLists.txt
+2
-2
examples/libtorchaudio/speech_recognition/build_pipeline_from_fairseq.py
...chaudio/speech_recognition/build_pipeline_from_fairseq.py
+4
-24
examples/libtorchaudio/speech_recognition/build_pipeline_from_huggingface_transformers.py
...cognition/build_pipeline_from_huggingface_transformers.py
+4
-26
examples/libtorchaudio/speech_recognition/greedy_decoder.py
examples/libtorchaudio/speech_recognition/greedy_decoder.py
+28
-0
No files found.
.gitmodules
View file @
d49e6e45
...
...
@@ -2,6 +2,3 @@
path = third_party/kaldi/submodule
url = https://github.com/kaldi-asr/kaldi
ignore = dirty
[submodule "examples/libtorchaudio/simplectc"]
path = examples/libtorchaudio/simplectc
url = https://github.com/mthrok/ctcdecode
examples/libtorchaudio/CMakeLists.txt
View file @
d49e6e45
...
...
@@ -14,6 +14,5 @@ message("libtorchaudio CMakeLists: ${TORCH_CXX_FLAGS}")
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
${
TORCH_CXX_FLAGS
}
"
)
add_subdirectory
(
../.. libtorchaudio
)
add_subdirectory
(
simplectc
)
add_subdirectory
(
augmentation
)
add_subdirectory
(
speech_recognition
)
simplectc
@
b1a30d7a
Compare
b1a30d7a
...
b1a30d7a
Subproject commit b1a30d7a65342012e0d2524d9bae1c5412b24a23
examples/libtorchaudio/speech_recognition/CMakeLists.txt
View file @
d49e6e45
add_executable
(
transcribe transcribe.cpp
)
add_executable
(
transcribe_list transcribe_list.cpp
)
target_link_libraries
(
transcribe
"
${
TORCH_LIBRARIES
}
"
"
${
TORCHAUDIO_LIBRARY
}
"
"
${
CTCDECODE_LIBRARY
}
"
)
target_link_libraries
(
transcribe_list
"
${
TORCH_LIBRARIES
}
"
"
${
TORCHAUDIO_LIBRARY
}
"
"
${
CTCDECODE_LIBRARY
}
"
)
target_link_libraries
(
transcribe
"
${
TORCH_LIBRARIES
}
"
"
${
TORCHAUDIO_LIBRARY
}
"
)
target_link_libraries
(
transcribe_list
"
${
TORCH_LIBRARIES
}
"
"
${
TORCHAUDIO_LIBRARY
}
"
)
set_property
(
TARGET transcribe PROPERTY CXX_STANDARD 14
)
set_property
(
TARGET transcribe_list PROPERTY CXX_STANDARD 14
)
examples/libtorchaudio/speech_recognition/build_pipeline_from_fairseq.py
View file @
d49e6e45
...
...
@@ -12,7 +12,8 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
import
torchaudio
from
torchaudio.models.wav2vec2.utils.import_fairseq
import
import_fairseq_model
import
fairseq
import
simple_ctc
from
greedy_decoder
import
Decoder
_LG
=
logging
.
getLogger
(
__name__
)
...
...
@@ -77,17 +78,7 @@ class Encoder(torch.nn.Module):
def
forward
(
self
,
waveform
:
torch
.
Tensor
)
->
torch
.
Tensor
:
result
,
_
=
self
.
encoder
(
waveform
)
return
result
class
Decoder
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
decoder
:
torch
.
nn
.
Module
):
super
().
__init__
()
self
.
decoder
=
decoder
def
forward
(
self
,
emission
:
torch
.
Tensor
)
->
str
:
result
=
self
.
decoder
.
decode
(
emission
)
return
''
.
join
(
result
.
label_sequences
[
0
][
0
]).
replace
(
'|'
,
' '
)
return
result
[
0
]
def
_get_decoder
():
...
...
@@ -125,18 +116,7 @@ def _get_decoder():
"Q"
,
"Z"
,
]
return
Decoder
(
simple_ctc
.
BeamSearchDecoder
(
labels
,
cutoff_top_n
=
40
,
cutoff_prob
=
0.8
,
beam_size
=
100
,
num_processes
=
1
,
blank_id
=
0
,
is_nll
=
True
,
)
)
return
Decoder
(
labels
)
def
_load_fairseq_model
(
input_file
,
data_dir
=
None
):
...
...
examples/libtorchaudio/speech_recognition/build_pipeline_from_huggingface_transformers.py
View file @
d49e6e45
...
...
@@ -6,8 +6,7 @@ import os
import
torch
import
torchaudio
from
torchaudio.models.wav2vec2.utils.import_huggingface
import
import_huggingface_model
import
simple_ctc
from
greedy_decoder
import
Decoder
_LG
=
logging
.
getLogger
(
__name__
)
...
...
@@ -59,19 +58,8 @@ class Encoder(torch.nn.Module):
self
.
encoder
=
encoder
def
forward
(
self
,
waveform
:
torch
.
Tensor
)
->
torch
.
Tensor
:
length
=
torch
.
tensor
([
waveform
.
shape
[
1
]])
result
,
length
=
self
.
encoder
(
waveform
,
length
)
return
result
class
Decoder
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
decoder
:
torch
.
nn
.
Module
):
super
().
__init__
()
self
.
decoder
=
decoder
def
forward
(
self
,
emission
:
torch
.
Tensor
)
->
str
:
result
=
self
.
decoder
.
decode
(
emission
)
return
''
.
join
(
result
.
label_sequences
[
0
][
0
]).
replace
(
'|'
,
' '
)
result
,
_
=
self
.
encoder
(
waveform
)
return
result
[
0
]
def
_get_model
(
model_id
):
...
...
@@ -84,17 +72,7 @@ def _get_model(model_id):
def
_get_decoder
(
labels
):
return
Decoder
(
simple_ctc
.
BeamSearchDecoder
(
labels
,
cutoff_top_n
=
40
,
cutoff_prob
=
0.8
,
beam_size
=
100
,
num_processes
=
1
,
blank_id
=
0
,
is_nll
=
True
,
)
)
return
Decoder
(
labels
)
def
_main
():
...
...
examples/libtorchaudio/speech_recognition/greedy_decoder.py
0 → 100644
View file @
d49e6e45
import
torch
class Decoder(torch.nn.Module):
    """Greedy (best-path) decoder that turns per-frame label scores into text."""

    def __init__(self, labels):
        super().__init__()
        # Label table; position ``k`` holds the string emitted for class ``k``.
        self.labels = labels

    def forward(self, logits: torch.Tensor) -> str:
        """Decode a sequence of logits into its best-path transcript.

        The highest-scoring label is picked for every frame, runs of
        identical consecutive labels are collapsed into one, the special
        labels ``<s>`` and ``<pad>`` are dropped, and the word delimiter
        ``|`` is rendered as a space.

        Args:
            logits (Tensor): Logit tensor. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        # Most likely label index for each frame: [num_seq,]
        frame_labels = torch.argmax(logits, dim=-1)
        # Merge repeated neighbouring indices into a single occurrence.
        collapsed = torch.unique_consecutive(frame_labels, dim=-1)

        transcript = ''
        for index in collapsed:
            token = self.labels[index]
            # Special labels carry no text of their own.
            if token == '<s>' or token == '<pad>':
                continue
            # '|' marks a word boundary in the label set.
            transcript += ' ' if token == '|' else token
        return transcript
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment