Commit 6057d3cf authored by Zhaoheng Ni, committed by Facebook GitHub Bot
Browse files

Add conv_tasnet_base factory function to prototype (#2411)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2411

Reviewed By: carolineechen

Differential Revision: D36663904

Pulled By: nateanl

fbshipit-source-id: c6a7dd530c9cfbb58b7121ebe02db6ae293cc2d0
parent 93024ace
......@@ -14,6 +14,11 @@ conformer_rnnt_base
.. autofunction:: conformer_rnnt_base
conv_tasnet_base
~~~~~~~~~~~~~~~~
.. autofunction:: conv_tasnet_base
ConvEmformer
~~~~~~~~~~~~
......
from .conv_emformer import ConvEmformer
from .conv_tasnet import conv_tasnet_base
from .rnnt import conformer_rnnt_base, conformer_rnnt_model
# Public names exported by this package; each entry is imported above from
# its submodule (.conv_emformer, .conv_tasnet, .rnnt).
__all__ = [
"conformer_rnnt_base",
"conformer_rnnt_model",
"conv_tasnet_base",
"ConvEmformer",
]
from torchaudio.models import ConvTasNet
def conv_tasnet_base(num_sources: int = 2) -> ConvTasNet:
    r"""Build the non-causal version of ConvTasNet from
    *Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
    [:footcite:`Luo_2019`].

    The parameter settings follow the configuration with the highest Si-SNR metric
    score in the paper, except that the mask activation function is changed from
    "sigmoid" to "relu" for performance improvement.

    Args:
        num_sources (int, optional): Number of sources in the output.
            (Default: 2)

    Returns:
        ConvTasNet:
            ConvTasNet model.
    """
    # Settings passed through the ``enc_*`` keyword arguments.
    encoder_config = {
        "enc_kernel_size": 16,
        "enc_num_feats": 512,
    }
    # Settings passed through the ``msk_*`` keyword arguments; "relu" is the
    # deliberate deviation from the paper's "sigmoid" mask activation.
    mask_config = {
        "msk_kernel_size": 3,
        "msk_num_feats": 128,
        "msk_num_hidden_feats": 512,
        "msk_num_layers": 8,
        "msk_num_stacks": 3,
        "msk_activate": "relu",
    }
    return ConvTasNet(num_sources=num_sources, **encoder_config, **mask_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment