Unverified commit 1f136671, authored by discort, committed by GitHub
Browse files

Add vanilla DeepSpeech model (#1399)


Co-authored-by: Vincent Quenneville-Belair <vincentqb@gmail.com>
parent 4b2de71f
...@@ -17,6 +17,14 @@ The models subpackage contains definitions of models for addressing common audio ...@@ -17,6 +17,14 @@ The models subpackage contains definitions of models for addressing common audio
.. automethod:: forward .. automethod:: forward
:hidden:`DeepSpeech`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: DeepSpeech
.. automethod:: forward
:hidden:`Wav2Letter` :hidden:`Wav2Letter`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -3,7 +3,7 @@ from collections import namedtuple ...@@ -3,7 +3,7 @@ from collections import namedtuple
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from torchaudio.models import ConvTasNet, Wav2Letter, WaveRNN from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN
from torchaudio.models.wavernn import MelResNet, UpsampleNetwork from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
...@@ -174,3 +174,20 @@ class TestConvTasNet(common_utils.TorchaudioTestCase): ...@@ -174,3 +174,20 @@ class TestConvTasNet(common_utils.TorchaudioTestCase):
output = model(tensor) output = model(tensor)
assert output.shape == (batch_size, num_sources, num_frames) assert output.shape == (batch_size, num_sources, num_frames)
class TestDeepSpeech(common_utils.TorchaudioTestCase):
    """Smoke test for the DeepSpeech model's output shape."""

    def test_deepspeech(self):
        """A (batch, channel, time, feature) input yields a (batch, time, class) output."""
        n_batch, n_channel, n_time, n_feature = 2, 1, 320, 1
        n_class = 40

        model = DeepSpeech(n_feature=n_feature, n_class=n_class)
        features = torch.rand(n_batch, n_channel, n_time, n_feature)

        out = model(features)

        expected_shape = (n_batch, n_time, n_class)
        assert out.size() == expected_shape
from .wav2letter import Wav2Letter from .wav2letter import Wav2Letter
from .wavernn import WaveRNN from .wavernn import WaveRNN
from .conv_tasnet import ConvTasNet from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech
__all__ = [ __all__ = [
'Wav2Letter', 'Wav2Letter',
'WaveRNN', 'WaveRNN',
'ConvTasNet', 'ConvTasNet',
'DeepSpeech',
] ]
import torch
__all__ = ["DeepSpeech"]
class FullyConnected(torch.nn.Module):
    """Linear layer followed by a clipped ReLU and optional dropout.

    Args:
        n_feature: Number of input features
        n_hidden: Internal hidden unit size.
        dropout: Dropout probability applied after the activation. A falsy
            value (e.g. ``0.0``) skips the dropout call entirely.
        relu_max_clip: Upper bound of the clipped ReLU. (Default: ``20``)
    """

    def __init__(self,
                 n_feature: int,
                 n_hidden: int,
                 dropout: float,
                 relu_max_clip: int = 20) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
        self.relu_max_clip = relu_max_clip
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x``.

        Args:
            x (torch.Tensor): Tensor whose last dimension has size ``n_feature``.

        Returns:
            Tensor: Same leading dimensions as ``x`` with a last dimension of
            size ``n_hidden``; values lie in ``[0, relu_max_clip]`` before dropout.
        """
        x = self.fc(x)
        # hardtanh with a lower bound of 0 is exactly a ReLU capped at
        # relu_max_clip, so a separate relu pass beforehand is redundant.
        x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip)
        if self.dropout:
            # Only active in training mode; a no-op when self.training is False.
            x = torch.nn.functional.dropout(x, self.dropout, self.training)
        return x
class DeepSpeech(torch.nn.Module):
    """DeepSpeech model: three fully connected layers, a bidirectional ReLU RNN,
    one more fully connected layer and a linear classifier with log-softmax.

    Architecture from `"Deep Speech: Scaling up end-to-end speech recognition"
    <https://arxiv.org/abs/1412.5567>`_.

    Args:
        n_feature: Number of input features
        n_hidden: Internal hidden unit size.
        n_class: Number of output classes
        dropout: Dropout value handed to every ``FullyConnected`` layer.
    """

    def __init__(
        self,
        n_feature: int,
        n_hidden: int = 2048,
        n_class: int = 40,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.n_hidden = n_hidden

        # Three non-recurrent layers applied to the input features.
        self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
        self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
        self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)

        # Single-layer bidirectional RNN; its output concatenates both directions.
        self.bi_rnn = torch.nn.RNN(
            n_hidden,
            n_hidden,
            num_layers=1,
            nonlinearity="relu",
            bidirectional=True,
        )

        # Non-recurrent layer after the RNN, then the classifier.
        self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
        self.out = torch.nn.Linear(n_hidden, n_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).
                The channel dimension is removed with ``squeeze(1)``, which only
                drops it when its size is 1.

        Returns:
            Tensor: Predictor tensor of dimension (batch, time, class) holding
            log-probabilities over the class dimension.
        """
        # (N, C, T, F) -> (N, C, T, H)
        hidden = self.fc3(self.fc2(self.fc1(x)))

        # (N, C, T, H) -> (N, T, H) -> (T, N, H): the RNN is time-major.
        hidden = hidden.squeeze(1).transpose(0, 1)

        # (T, N, 2 * H): forward units first, backward units second.
        rnn_out, _ = self.bi_rnn(hidden)

        # The fifth (non-recurrent) layer takes both the forward and backward
        # units as inputs; they are merged by summation -> (T, N, H).
        forward_units = rnn_out[:, :, :self.n_hidden]
        backward_units = rnn_out[:, :, self.n_hidden:]
        merged = forward_units + backward_units

        # (T, N, H) -> (T, N, n_class)
        scores = self.out(self.fc4(merged))

        # Back to batch-major: (N, T, n_class)
        scores = scores.permute(1, 0, 2)

        return torch.nn.functional.log_softmax(scores, dim=2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment