Unverified Commit d49e6e45 authored by moto's avatar moto Committed by GitHub
Browse files

Replace simple_ctc with Python greedy decoder (#1558)

parent 1b52e720
......@@ -2,6 +2,3 @@
path = third_party/kaldi/submodule
url = https://github.com/kaldi-asr/kaldi
ignore = dirty
[submodule "examples/libtorchaudio/simplectc"]
path = examples/libtorchaudio/simplectc
url = https://github.com/mthrok/ctcdecode
......@@ -14,6 +14,5 @@ message("libtorchaudio CMakeLists: ${TORCH_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
add_subdirectory(../.. libtorchaudio)
add_subdirectory(simplectc)
add_subdirectory(augmentation)
add_subdirectory(speech_recognition)
Subproject commit b1a30d7a65342012e0d2524d9bae1c5412b24a23
add_executable(transcribe transcribe.cpp)
add_executable(transcribe_list transcribe_list.cpp)
target_link_libraries(transcribe "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}" "${CTCDECODE_LIBRARY}")
target_link_libraries(transcribe_list "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}" "${CTCDECODE_LIBRARY}")
target_link_libraries(transcribe "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}")
target_link_libraries(transcribe_list "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}")
set_property(TARGET transcribe PROPERTY CXX_STANDARD 14)
set_property(TARGET transcribe_list PROPERTY CXX_STANDARD 14)
......@@ -12,7 +12,8 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
import torchaudio
from torchaudio.models.wav2vec2.utils.import_fairseq import import_fairseq_model
import fairseq
import simple_ctc
from greedy_decoder import Decoder
_LG = logging.getLogger(__name__)
......@@ -77,17 +78,7 @@ class Encoder(torch.nn.Module):
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
result, _ = self.encoder(waveform)
return result
class Decoder(torch.nn.Module):
def __init__(self, decoder: torch.nn.Module):
super().__init__()
self.decoder = decoder
def forward(self, emission: torch.Tensor) -> str:
result = self.decoder.decode(emission)
return ''.join(result.label_sequences[0][0]).replace('|', ' ')
return result[0]
def _get_decoder():
......@@ -125,18 +116,7 @@ def _get_decoder():
"Q",
"Z",
]
return Decoder(
simple_ctc.BeamSearchDecoder(
labels,
cutoff_top_n=40,
cutoff_prob=0.8,
beam_size=100,
num_processes=1,
blank_id=0,
is_nll=True,
)
)
return Decoder(labels)
def _load_fairseq_model(input_file, data_dir=None):
......
......@@ -6,8 +6,7 @@ import os
import torch
import torchaudio
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
import simple_ctc
from greedy_decoder import Decoder
_LG = logging.getLogger(__name__)
......@@ -59,19 +58,8 @@ class Encoder(torch.nn.Module):
self.encoder = encoder
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
length = torch.tensor([waveform.shape[1]])
result, length = self.encoder(waveform, length)
return result
class Decoder(torch.nn.Module):
def __init__(self, decoder: torch.nn.Module):
super().__init__()
self.decoder = decoder
def forward(self, emission: torch.Tensor) -> str:
result = self.decoder.decode(emission)
return ''.join(result.label_sequences[0][0]).replace('|', ' ')
result, _ = self.encoder(waveform)
return result[0]
def _get_model(model_id):
......@@ -84,17 +72,7 @@ def _get_model(model_id):
def _get_decoder(labels):
return Decoder(
simple_ctc.BeamSearchDecoder(
labels,
cutoff_top_n=40,
cutoff_prob=0.8,
beam_size=100,
num_processes=1,
blank_id=0,
is_nll=True,
)
)
return Decoder(labels)
def _main():
......
import torch
class Decoder(torch.nn.Module):
    """Greedy (best-path) decoder that turns frame-wise logits into text.

    For every frame the highest-scoring label is picked, consecutive
    duplicates are collapsed, special tokens are dropped and the word
    delimiter ``|`` is rendered as a space.

    Args:
        labels: Sequence of label strings. The position of a label in this
            sequence is the class index used along the last dimension of
            the logits passed to :meth:`forward`.
    """

    def __init__(self, labels):
        super().__init__()
        self.labels = labels

    def forward(self, logits: torch.Tensor) -> str:
        """Given a sequence logits over labels, get the best path string

        Args:
            logits (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        best_path = torch.argmax(logits, dim=-1)  # [num_seq,]
        # Collapse runs of the same label *before* dropping special tokens,
        # so a special token between two identical characters still
        # separates them (e.g. "A <s> A" decodes to "AA", not "A").
        best_path = torch.unique_consecutive(best_path, dim=-1)
        hypothesis = ''
        for i in best_path:
            char = self.labels[i]
            # Special tokens carry no text; none of them may leak into the
            # transcript. `</s>` and `<unk>` are skipped alongside the
            # `<s>` / `<pad>` tokens.
            if char in ['<s>', '<pad>', '</s>', '<unk>']:
                continue
            if char == '|':
                # `|` is the word delimiter in the label set.
                char = ' '
            hypothesis += char
        return hypothesis
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment