"src/vscode:/vscode.git/clone" did not exist on "5780776c8a13456788089eb5c4a3939be0c2c779"
Unverified Commit d49e6e45 authored by moto's avatar moto Committed by GitHub
Browse files

Replace simple_ctc with Python greedy decoder (#1558)

parent 1b52e720
...@@ -2,6 +2,3 @@ ...@@ -2,6 +2,3 @@
path = third_party/kaldi/submodule path = third_party/kaldi/submodule
url = https://github.com/kaldi-asr/kaldi url = https://github.com/kaldi-asr/kaldi
ignore = dirty ignore = dirty
[submodule "examples/libtorchaudio/simplectc"]
path = examples/libtorchaudio/simplectc
url = https://github.com/mthrok/ctcdecode
...@@ -14,6 +14,5 @@ message("libtorchaudio CMakeLists: ${TORCH_CXX_FLAGS}") ...@@ -14,6 +14,5 @@ message("libtorchaudio CMakeLists: ${TORCH_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
add_subdirectory(../.. libtorchaudio) add_subdirectory(../.. libtorchaudio)
add_subdirectory(simplectc)
add_subdirectory(augmentation) add_subdirectory(augmentation)
add_subdirectory(speech_recognition) add_subdirectory(speech_recognition)
Subproject commit b1a30d7a65342012e0d2524d9bae1c5412b24a23
add_executable(transcribe transcribe.cpp) add_executable(transcribe transcribe.cpp)
add_executable(transcribe_list transcribe_list.cpp) add_executable(transcribe_list transcribe_list.cpp)
target_link_libraries(transcribe "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}" "${CTCDECODE_LIBRARY}") target_link_libraries(transcribe "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}")
target_link_libraries(transcribe_list "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}" "${CTCDECODE_LIBRARY}") target_link_libraries(transcribe_list "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}")
set_property(TARGET transcribe PROPERTY CXX_STANDARD 14) set_property(TARGET transcribe PROPERTY CXX_STANDARD 14)
set_property(TARGET transcribe_list PROPERTY CXX_STANDARD 14) set_property(TARGET transcribe_list PROPERTY CXX_STANDARD 14)
...@@ -12,7 +12,8 @@ from torch.utils.mobile_optimizer import optimize_for_mobile ...@@ -12,7 +12,8 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
import torchaudio import torchaudio
from torchaudio.models.wav2vec2.utils.import_fairseq import import_fairseq_model from torchaudio.models.wav2vec2.utils.import_fairseq import import_fairseq_model
import fairseq import fairseq
import simple_ctc
from greedy_decoder import Decoder
_LG = logging.getLogger(__name__) _LG = logging.getLogger(__name__)
...@@ -77,17 +78,7 @@ class Encoder(torch.nn.Module): ...@@ -77,17 +78,7 @@ class Encoder(torch.nn.Module):
def forward(self, waveform: torch.Tensor) -> torch.Tensor: def forward(self, waveform: torch.Tensor) -> torch.Tensor:
result, _ = self.encoder(waveform) result, _ = self.encoder(waveform)
return result return result[0]
class Decoder(torch.nn.Module):
def __init__(self, decoder: torch.nn.Module):
super().__init__()
self.decoder = decoder
def forward(self, emission: torch.Tensor) -> str:
result = self.decoder.decode(emission)
return ''.join(result.label_sequences[0][0]).replace('|', ' ')
def _get_decoder(): def _get_decoder():
...@@ -125,18 +116,7 @@ def _get_decoder(): ...@@ -125,18 +116,7 @@ def _get_decoder():
"Q", "Q",
"Z", "Z",
] ]
return Decoder(labels)
return Decoder(
simple_ctc.BeamSearchDecoder(
labels,
cutoff_top_n=40,
cutoff_prob=0.8,
beam_size=100,
num_processes=1,
blank_id=0,
is_nll=True,
)
)
def _load_fairseq_model(input_file, data_dir=None): def _load_fairseq_model(input_file, data_dir=None):
......
...@@ -6,8 +6,7 @@ import os ...@@ -6,8 +6,7 @@ import os
import torch import torch
import torchaudio import torchaudio
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
import simple_ctc from greedy_decoder import Decoder
_LG = logging.getLogger(__name__) _LG = logging.getLogger(__name__)
...@@ -59,19 +58,8 @@ class Encoder(torch.nn.Module): ...@@ -59,19 +58,8 @@ class Encoder(torch.nn.Module):
self.encoder = encoder self.encoder = encoder
def forward(self, waveform: torch.Tensor) -> torch.Tensor: def forward(self, waveform: torch.Tensor) -> torch.Tensor:
length = torch.tensor([waveform.shape[1]]) result, _ = self.encoder(waveform)
result, length = self.encoder(waveform, length) return result[0]
return result
class Decoder(torch.nn.Module):
def __init__(self, decoder: torch.nn.Module):
super().__init__()
self.decoder = decoder
def forward(self, emission: torch.Tensor) -> str:
result = self.decoder.decode(emission)
return ''.join(result.label_sequences[0][0]).replace('|', ' ')
def _get_model(model_id): def _get_model(model_id):
...@@ -84,17 +72,7 @@ def _get_model(model_id): ...@@ -84,17 +72,7 @@ def _get_model(model_id):
def _get_decoder(labels): def _get_decoder(labels):
return Decoder( return Decoder(labels)
simple_ctc.BeamSearchDecoder(
labels,
cutoff_top_n=40,
cutoff_prob=0.8,
beam_size=100,
num_processes=1,
blank_id=0,
is_nll=True,
)
)
def _main(): def _main():
......
import torch


class Decoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder.

    Picks the most likely label at every time step, collapses consecutive
    repeats, and maps label indices to characters. Written to stay
    TorchScript-scriptable (plain loop, list-of-str accumulation).
    """

    def __init__(self, labels):
        super().__init__()
        # labels: index -> token string; '|' denotes a word boundary,
        # '<s>' / '<pad>' are dropped from the transcript.
        self.labels = labels

    def forward(self, logits: torch.Tensor) -> str:
        """Given a sequence logits over labels, get the best path string

        Args:
            logits (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        # Best label per frame, then merge runs of identical labels.
        indices = torch.argmax(logits, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        pieces = []
        for idx in indices:
            token = self.labels[idx]
            if token == '|':
                pieces.append(' ')
            elif token not in ['<s>', '<pad>']:
                pieces.append(token)
        return ''.join(pieces)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment