Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
8e370559
Unverified
Commit
8e370559
authored
Sep 28, 2020
by
moto
Committed by
GitHub
Sep 28, 2020
Browse files
Add ConvTasNet model (#920)
parent
9c274228
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
390 additions
and
1 deletion
+390
-1
test/torchaudio_unittest/models_test.py
test/torchaudio_unittest/models_test.py
+66
-1
torchaudio/models/__init__.py
torchaudio/models/__init__.py
+1
-0
torchaudio/models/conv_tasnet.py
torchaudio/models/conv_tasnet.py
+323
-0
No files found.
test/torchaudio_unittest/models_test.py
View file @
8e370559
import
itertools
from
collections
import
namedtuple
import
torch
from
torchaudio.models
import
Wav2Letter
,
MelResNet
,
UpsampleNetwork
,
WaveRNN
from
torchaudio.models
import
(
Wav2Letter
,
MelResNet
,
UpsampleNetwork
,
WaveRNN
,
ConvTasNet
,
)
from
parameterized
import
parameterized
from
torchaudio_unittest
import
common_utils
...
...
@@ -115,3 +125,58 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
out
=
model
(
x
,
mels
)
assert
out
.
size
()
==
(
n_batch
,
1
,
hop_length
*
(
n_time
-
kernel_size
+
1
),
n_classes
)
_ConvTasNetParams
=
namedtuple
(
'_ConvTasNetParams'
,
[
'enc_num_feats'
,
'enc_kernel_size'
,
'msk_num_feats'
,
'msk_num_hidden_feats'
,
'msk_kernel_size'
,
'msk_num_layers'
,
'msk_num_stacks'
,
]
)
class
TestConvTasNet
(
common_utils
.
TorchaudioTestCase
):
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
2
,
3
],
[
_ConvTasNetParams
(
128
,
40
,
128
,
256
,
3
,
7
,
2
),
_ConvTasNetParams
(
256
,
40
,
128
,
256
,
3
,
7
,
2
),
_ConvTasNetParams
(
512
,
40
,
128
,
256
,
3
,
7
,
2
),
_ConvTasNetParams
(
512
,
40
,
128
,
256
,
3
,
7
,
2
),
_ConvTasNetParams
(
512
,
40
,
128
,
512
,
3
,
7
,
2
),
_ConvTasNetParams
(
512
,
40
,
128
,
512
,
3
,
7
,
2
),
_ConvTasNetParams
(
512
,
40
,
256
,
256
,
3
,
7
,
2
),
_ConvTasNetParams
(
512
,
40
,
256
,
512
,
3
,
7
,
2
),
_ConvTasNetParams
(
512
,
40
,
256
,
512
,
3
,
7
,
2
),
_ConvTasNetParams
(
512
,
40
,
128
,
512
,
3
,
6
,
4
),
_ConvTasNetParams
(
512
,
40
,
128
,
512
,
3
,
4
,
6
),
_ConvTasNetParams
(
512
,
40
,
128
,
512
,
3
,
8
,
3
),
_ConvTasNetParams
(
512
,
32
,
128
,
512
,
3
,
8
,
3
),
_ConvTasNetParams
(
512
,
16
,
128
,
512
,
3
,
8
,
3
),
],
)))
def
test_paper_configuration
(
self
,
num_sources
,
model_params
):
"""ConvTasNet model works on the valid configurations in the paper"""
batch_size
=
32
num_frames
=
8000
model
=
ConvTasNet
(
num_sources
=
num_sources
,
enc_kernel_size
=
model_params
.
enc_kernel_size
,
enc_num_feats
=
model_params
.
enc_num_feats
,
msk_kernel_size
=
model_params
.
msk_kernel_size
,
msk_num_feats
=
model_params
.
msk_num_feats
,
msk_num_hidden_feats
=
model_params
.
msk_num_hidden_feats
,
msk_num_layers
=
model_params
.
msk_num_layers
,
msk_num_stacks
=
model_params
.
msk_num_stacks
,
)
tensor
=
torch
.
rand
(
batch_size
,
1
,
num_frames
)
output
=
model
(
tensor
)
assert
output
.
shape
==
(
batch_size
,
num_sources
,
num_frames
)
torchaudio/models/__init__.py
View file @
8e370559
from
.wav2letter
import
*
from
.wavernn
import
*
from
.conv_tasnet
import
ConvTasNet
torchaudio/models/conv_tasnet.py
0 → 100644
View file @
8e370559
"""Implements Conv-TasNet with building blocks of it.
Based on https://github.com/naplab/Conv-TasNet/tree/e66d82a8f956a69749ec8a4ae382217faa097c5c
"""
from
typing
import
Tuple
,
Optional
import
torch
class
ConvBlock
(
torch
.
nn
.
Module
):
"""1D Convolutional block.
Args:
io_channels (int): The number of input/output channels, <B, Sc>
hidden_channels (int): The number of channels in the internal layers, <H>.
kernel_size (int): The convolution kernel size of the middle layer, <P>.
padding (int): Padding value of the convolution in the middle layer.
dilation (int): Dilation value of the convolution in the middle layer.
no_redisual (bool): Disable residual block/output.
Note:
This implementation corresponds to the "non-causal" setting in the paper.
References:
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
"""
def
__init__
(
self
,
io_channels
:
int
,
hidden_channels
:
int
,
kernel_size
:
int
,
padding
:
int
,
dilation
:
int
=
1
,
no_residual
:
bool
=
False
,
):
super
().
__init__
()
self
.
conv_layers
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Conv1d
(
in_channels
=
io_channels
,
out_channels
=
hidden_channels
,
kernel_size
=
1
),
torch
.
nn
.
PReLU
(),
torch
.
nn
.
GroupNorm
(
num_groups
=
1
,
num_channels
=
hidden_channels
,
eps
=
1e-08
),
torch
.
nn
.
Conv1d
(
in_channels
=
hidden_channels
,
out_channels
=
hidden_channels
,
kernel_size
=
kernel_size
,
padding
=
padding
,
dilation
=
dilation
,
groups
=
hidden_channels
,
),
torch
.
nn
.
PReLU
(),
torch
.
nn
.
GroupNorm
(
num_groups
=
1
,
num_channels
=
hidden_channels
,
eps
=
1e-08
),
)
self
.
res_out
=
(
None
if
no_residual
else
torch
.
nn
.
Conv1d
(
in_channels
=
hidden_channels
,
out_channels
=
io_channels
,
kernel_size
=
1
)
)
self
.
skip_out
=
torch
.
nn
.
Conv1d
(
in_channels
=
hidden_channels
,
out_channels
=
io_channels
,
kernel_size
=
1
)
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
Tuple
[
Optional
[
torch
.
Tensor
],
torch
.
Tensor
]:
feature
=
self
.
conv_layers
(
input
)
if
self
.
res_out
is
None
:
residual
=
None
else
:
residual
=
self
.
res_out
(
feature
)
skip_out
=
self
.
skip_out
(
feature
)
return
residual
,
skip_out
class
MaskGenerator
(
torch
.
nn
.
Module
):
"""TCN (Temporal Convolution Network) Separation Module
Generates masks for separation.
Args:
input_dim (int): Input feature dimension, <N>.
num_sources (int): The number of sources to separate.
kernel_size (int): The convolution kernel size of conv blocks, <P>.
num_featrs (int): Input/output feature dimenstion of conv blocks, <B, Sc>.
num_hidden (int): Intermediate feature dimention of conv blocks, <H>
num_layers (int): The number of conv blocks in one stack, <X>.
num_stacks (int): The number of conv block stacks, <R>.
Note:
This implementation corresponds to the "non-causal" setting in the paper.
References:
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
"""
def
__init__
(
self
,
input_dim
:
int
,
num_sources
:
int
,
kernel_size
:
int
,
num_feats
:
int
,
num_hidden
:
int
,
num_layers
:
int
,
num_stacks
:
int
,
):
super
().
__init__
()
self
.
input_dim
=
input_dim
self
.
num_sources
=
num_sources
self
.
input_norm
=
torch
.
nn
.
GroupNorm
(
num_groups
=
1
,
num_channels
=
input_dim
,
eps
=
1e-8
)
self
.
input_conv
=
torch
.
nn
.
Conv1d
(
in_channels
=
input_dim
,
out_channels
=
num_feats
,
kernel_size
=
1
)
self
.
receptive_field
=
0
self
.
conv_layers
=
torch
.
nn
.
ModuleList
([])
for
s
in
range
(
num_stacks
):
for
l
in
range
(
num_layers
):
multi
=
2
**
l
self
.
conv_layers
.
append
(
ConvBlock
(
io_channels
=
num_feats
,
hidden_channels
=
num_hidden
,
kernel_size
=
kernel_size
,
dilation
=
multi
,
padding
=
multi
,
# The last ConvBlock does not need residual
no_residual
=
(
l
==
(
num_layers
-
1
)
and
s
==
(
num_stacks
-
1
)),
)
)
self
.
receptive_field
+=
(
kernel_size
if
s
==
0
and
l
==
0
else
(
kernel_size
-
1
)
*
multi
)
self
.
output_prelu
=
torch
.
nn
.
PReLU
()
self
.
output_conv
=
torch
.
nn
.
Conv1d
(
in_channels
=
num_feats
,
out_channels
=
input_dim
*
num_sources
,
kernel_size
=
1
,
)
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Generate separation mask.
Args:
input (torch.Tensor): 3D Tensor with shape [batch, features, frames]
Returns:
torch.Tensor: shape [batch, num_sources, features, frames]
"""
batch_size
=
input
.
shape
[
0
]
feats
=
self
.
input_norm
(
input
)
feats
=
self
.
input_conv
(
feats
)
output
=
0.0
for
layer
in
self
.
conv_layers
:
residual
,
skip
=
layer
(
feats
)
if
residual
is
not
None
:
# the last conv layer does not produce residual
feats
=
feats
+
residual
output
=
output
+
skip
output
=
self
.
output_prelu
(
output
)
output
=
self
.
output_conv
(
output
)
output
=
torch
.
sigmoid
(
output
)
return
output
.
view
(
batch_size
,
self
.
num_sources
,
self
.
input_dim
,
-
1
)
class
ConvTasNet
(
torch
.
nn
.
Module
):
"""Conv-TasNet: a fully-convolutional time-domain audio separation network
Args:
num_sources (int): The number of sources to split.
enc_kernel_size (int): The convolution kernel size of the encoder/decoder, <L>.
enc_num_feats (int): The feature dimensions passed to mask generator, <N>.
msk_kernel_size (int): The convolution kernel size of the mask generator, <P>.
msk_num_feats (int): The input/output feature dimension of conv block in the mask generator, <B, Sc>.
msk_num_hidden_feats (int): The internal feature dimension of conv block of the mask generator, <H>.
msk_num_layers (int): The number of layers in one conv block of the mask generator, <X>.
msk_num_stacks (int): The numbr of conv blocks of the mask generator, <R>.
Note:
This implementation corresponds to the "non-causal" setting in the paper.
References:
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
"""
def
__init__
(
self
,
num_sources
:
int
=
2
,
# encoder/decoder parameters
enc_kernel_size
:
int
=
16
,
enc_num_feats
:
int
=
512
,
# mask generator parameters
msk_kernel_size
:
int
=
3
,
msk_num_feats
:
int
=
128
,
msk_num_hidden_feats
:
int
=
512
,
msk_num_layers
:
int
=
8
,
msk_num_stacks
:
int
=
3
,
):
super
().
__init__
()
self
.
num_sources
=
num_sources
self
.
enc_num_feats
=
enc_num_feats
self
.
enc_kernel_size
=
enc_kernel_size
self
.
enc_stride
=
enc_kernel_size
//
2
self
.
encoder
=
torch
.
nn
.
Conv1d
(
in_channels
=
1
,
out_channels
=
enc_num_feats
,
kernel_size
=
enc_kernel_size
,
stride
=
self
.
enc_stride
,
padding
=
self
.
enc_stride
,
bias
=
False
,
)
self
.
mask_generator
=
MaskGenerator
(
input_dim
=
enc_num_feats
,
num_sources
=
num_sources
,
kernel_size
=
msk_kernel_size
,
num_feats
=
msk_num_feats
,
num_hidden
=
msk_num_hidden_feats
,
num_layers
=
msk_num_layers
,
num_stacks
=
msk_num_stacks
,
)
self
.
decoder
=
torch
.
nn
.
ConvTranspose1d
(
in_channels
=
enc_num_feats
,
out_channels
=
1
,
kernel_size
=
enc_kernel_size
,
stride
=
self
.
enc_stride
,
padding
=
self
.
enc_stride
,
bias
=
False
,
)
def
_align_num_frames_with_strides
(
self
,
input
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
int
]:
"""Pad input Tensor so that the end of the input tensor corresponds with
1. (if kernel size is odd) the center of the last convolution kernel
or 2. (if kernel size is even) the end of the first half of the last convolution kernel
Assumption:
The resulting Tensor will be padded with the size of stride (== kernel_width // 2)
on the both ends in Conv1D
|<--- k_1 --->|
| | |<-- k_n-1 -->|
| | | |<--- k_n --->|
| | | | |
| | | | |
| v v v |
|<---->|<--- input signal --->|<--->|<---->|
stride PAD stride
Args:
input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames)
Returns:
torch.Tensor: Padded Tensor
int: Number of paddings performed
"""
batch_size
,
num_channels
,
num_frames
=
input
.
shape
is_odd
=
self
.
enc_kernel_size
%
2
num_strides
=
(
num_frames
-
is_odd
)
//
self
.
enc_stride
num_remainings
=
num_frames
-
(
is_odd
+
num_strides
*
self
.
enc_stride
)
if
num_remainings
==
0
:
return
input
,
0
num_paddings
=
self
.
enc_stride
-
num_remainings
pad
=
torch
.
zeros
(
batch_size
,
num_channels
,
num_paddings
,
dtype
=
input
.
dtype
,
device
=
input
.
device
,
)
return
torch
.
cat
([
input
,
pad
],
2
),
num_paddings
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Perform source separation. Generate audio source waveforms.
Args:
input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames]
Returns:
torch.Tensor: 3D Tensor with shape [batch, channel==num_sources, frames]
"""
if
input
.
ndim
!=
3
or
input
.
shape
[
1
]
!=
1
:
raise
ValueError
(
f
"Expected 3D tensor (batch, channel==1, frames). Found:
{
input
.
shape
}
"
)
# B: batch size
# L: input frame length
# L': padded input frame length
# F: feature dimension
# M: feature frame length
# S: number of sources
padded
,
num_pads
=
self
.
_align_num_frames_with_strides
(
input
)
# B, 1, L'
batch_size
,
num_padded_frames
=
padded
.
shape
[
0
],
padded
.
shape
[
2
]
feats
=
self
.
encoder
(
padded
)
# B, F, M
masked
=
self
.
mask_generator
(
feats
)
*
feats
.
unsqueeze
(
1
)
# B, S, F, M
masked
=
masked
.
view
(
batch_size
*
self
.
num_sources
,
self
.
enc_num_feats
,
-
1
)
# B*S, F, M
decoded
=
self
.
decoder
(
masked
)
# B*S, 1, L'
output
=
decoded
.
view
(
batch_size
,
self
.
num_sources
,
num_padded_frames
)
# B, S, L'
if
num_pads
>
0
:
output
=
output
[...,
:
-
num_pads
]
# B, S, L
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment