OpenDAS / Torchaudio

Commit 1f136671 (Unverified)
Authored by discort on May 12, 2021; committed by GitHub on May 11, 2021
Add vanilla DeepSpeech model (#1399)
Co-authored-by: Vincent Quenneville-Belair <vincentqb@gmail.com>

Parent: 4b2de71f
Showing 4 changed files with 120 additions and 1 deletion (+120 -1)
docs/source/models.rst                     +8  -0
test/torchaudio_unittest/models_test.py    +18 -1
torchaudio/models/__init__.py              +2  -0
torchaudio/models/deepspeech.py            +92 -0
docs/source/models.rst

@@ -17,6 +17,14 @@ The models subpackage contains definitions of models for addressing common audio
 
   .. automethod:: forward
 
 
+:hidden:`DeepSpeech`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: DeepSpeech
+
+  .. automethod:: forward
+
+
 :hidden:`Wav2Letter`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
test/torchaudio_unittest/models_test.py

@@ -3,7 +3,7 @@ from collections import namedtuple
 import torch
 from parameterized import parameterized
 
-from torchaudio.models import ConvTasNet, Wav2Letter, WaveRNN
+from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN
 from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
 from torchaudio_unittest import common_utils
 
@@ -174,3 +174,20 @@ class TestConvTasNet(common_utils.TorchaudioTestCase):
         output = model(tensor)
 
         assert output.shape == (batch_size, num_sources, num_frames)
+
+
+class TestDeepSpeech(common_utils.TorchaudioTestCase):
+
+    def test_deepspeech(self):
+        n_batch = 2
+        n_feature = 1
+        n_channel = 1
+        n_class = 40
+        n_time = 320
+
+        model = DeepSpeech(n_feature=n_feature, n_class=n_class)
+
+        x = torch.rand(n_batch, n_channel, n_time, n_feature)
+        out = model(x)
+
+        assert out.size() == (n_batch, n_time, n_class)
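The new test drives DeepSpeech with a random tensor of shape (batch, channel, time, feature). As a rough illustration of how such an input could be produced from audio rather than torch.rand, here is a minimal sketch using torchaudio.transforms.MelSpectrogram; the sample rate, n_mels, and n_class values are arbitrary choices for the example, not something this commit prescribes.

import torch
import torchaudio
from torchaudio.models import DeepSpeech

n_mels = 80                                     # assumed feature size for the sketch
waveform = torch.randn(1, 16000)                # stands in for one second of 16 kHz audio, shape (batch, samples)

melspec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=n_mels)
feats = melspec(waveform)                       # (batch, n_mels, n_frames)
feats = feats.transpose(1, 2).unsqueeze(1)      # -> (batch, channel=1, time, feature)

model = DeepSpeech(n_feature=n_mels, n_class=40)
out = model(feats)                              # (batch, time, n_class) log-probabilities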
torchaudio/models/__init__.py

 from .wav2letter import Wav2Letter
 from .wavernn import WaveRNN
 from .conv_tasnet import ConvTasNet
+from .deepspeech import DeepSpeech
 
 __all__ = [
     'Wav2Letter',
     'WaveRNN',
     'ConvTasNet',
+    'DeepSpeech',
 ]
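Because the class is re-exported here and added to __all__, it becomes part of the public torchaudio.models namespace. A quick sanity check, not part of the commit:

import torchaudio.models as models

assert "DeepSpeech" in models.__all__   # listed in the public API
assert hasattr(models, "DeepSpeech")    # importable via torchaudio.models.DeepSpeech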
torchaudio/models/deepspeech.py (new file, mode 100644)

import torch


__all__ = ["DeepSpeech"]


class FullyConnected(torch.nn.Module):
    """
    Args:
        n_feature: Number of input features
        n_hidden: Internal hidden unit size.
    """

    def __init__(self,
                 n_feature: int,
                 n_hidden: int,
                 dropout: float,
                 relu_max_clip: int = 20) -> None:
        super(FullyConnected, self).__init__()
        self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
        self.relu_max_clip = relu_max_clip
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc(x)
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip)
        if self.dropout:
            x = torch.nn.functional.dropout(x, self.dropout, self.training)
        return x


class DeepSpeech(torch.nn.Module):
    """
    DeepSpeech model architecture from
    `"Deep Speech: Scaling up end-to-end speech recognition"
    <https://arxiv.org/abs/1412.5567>`_ paper.

    Args:
        n_feature: Number of input features
        n_hidden: Internal hidden unit size.
        n_class: Number of output classes
    """

    def __init__(
        self,
        n_feature: int,
        n_hidden: int = 2048,
        n_class: int = 40,
        dropout: float = 0.0,
    ) -> None:
        super(DeepSpeech, self).__init__()
        self.n_hidden = n_hidden
        self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
        self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
        self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
        self.bi_rnn = torch.nn.RNN(
            n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True)
        self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
        self.out = torch.nn.Linear(n_hidden, n_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).

        Returns:
            Tensor: Predictor tensor of dimension (batch, time, class).
        """
        # N x C x T x F
        x = self.fc1(x)
        # N x C x T x H
        x = self.fc2(x)
        # N x C x T x H
        x = self.fc3(x)
        # N x C x T x H
        x = x.squeeze(1)
        # N x T x H
        x = x.transpose(0, 1)
        # T x N x H
        x, _ = self.bi_rnn(x)
        # The fifth (non-recurrent) layer takes both the forward and backward units as inputs
        x = x[:, :, :self.n_hidden] + x[:, :, self.n_hidden:]
        # T x N x H
        x = self.fc4(x)
        # T x N x H
        x = self.out(x)
        # T x N x n_class
        x = x.permute(1, 0, 2)
        # N x T x n_class
        x = torch.nn.functional.log_softmax(x, dim=2)
        # N x T x n_class
        return x
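The forward pass ends in log_softmax over the class dimension, so the (batch, time, class) output can be fed to torch.nn.CTCLoss after moving time to the front, which is the training objective used by the original Deep Speech system. Below is a minimal training-step sketch; the class count, blank index, and sequence lengths are illustrative assumptions, not part of this commit.

import torch
from torchaudio.models import DeepSpeech

n_batch, n_channel, n_time, n_feature = 2, 1, 320, 40
n_class = 29                                            # e.g. 26 letters + apostrophe + space + blank (assumed)

model = DeepSpeech(n_feature=n_feature, n_class=n_class)
x = torch.rand(n_batch, n_channel, n_time, n_feature)
log_probs = model(x)                                    # (batch, time, class), already log-probabilities

# torch.nn.CTCLoss expects (time, batch, class); index 0 is used as the blank label here by assumption.
ctc_loss = torch.nn.CTCLoss(blank=0)
targets = torch.randint(1, n_class, (n_batch, 50), dtype=torch.long)   # placeholder label sequences
input_lengths = torch.full((n_batch,), n_time, dtype=torch.long)
target_lengths = torch.full((n_batch,), 50, dtype=torch.long)

loss = ctc_loss(log_probs.transpose(0, 1), targets, input_lengths, target_lengths)
loss.backward()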