Commit 62854588 authored by Sean Kim's avatar Sean Kim Committed by Facebook GitHub Bot
Browse files

Adding pipeline changes, factory functions to HDemucs (#2547)

Summary:
Factory functions have been added for the HDemucs class, and tests for the implementation have been added to the test files.

Pull Request resolved: https://github.com/pytorch/audio/pull/2547

Reviewed By: carolineechen

Differential Revision: D37948600

Pulled By: skim0514

fbshipit-source-id: 7ac4e4a71519450cfbbc24ff7d7e70521f676040
parent af6ebbae
...@@ -82,7 +82,7 @@ fi ...@@ -82,7 +82,7 @@ fi
( (
set -x set -x
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20' conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
pip install kaldi-io SoundFile coverage pytest pytest-cov 'scipy==1.7.3' transformers expecttest unidecode inflect Pillow sentencepiece pytorch-lightning 'protobuf<4.21.0' pip install kaldi-io SoundFile coverage pytest pytest-cov 'scipy==1.7.3' transformers expecttest unidecode inflect Pillow sentencepiece pytorch-lightning 'protobuf<4.21.0' demucs
) )
# Install fairseq # Install fairseq
git clone https://github.com/pytorch/fairseq git clone https://github.com/pytorch/fairseq
......
...@@ -87,7 +87,8 @@ esac ...@@ -87,7 +87,8 @@ esac
'scipy==1.7.3' \ 'scipy==1.7.3' \
transformers \ transformers \
unidecode \ unidecode \
'protobuf<4.21.0' 'protobuf<4.21.0' \
demucs
) )
# Install fairseq # Install fairseq
git clone https://github.com/pytorch/fairseq git clone https://github.com/pytorch/fairseq
......
...@@ -28,13 +28,37 @@ ConvEmformer ...@@ -28,13 +28,37 @@ ConvEmformer
.. automethod:: infer .. automethod:: infer
Hybrid Demucs
~~~~~~~~~~~~~
Model
-----
HDemucs HDemucs
~~~~~~~ ^^^^^^^
.. autoclass:: HDemucs .. autoclass:: HDemucs
.. automethod:: forward .. automethod:: forward
Factory Functions
-----------------
hdemucs_low
^^^^^^^^^^^
.. autofunction:: hdemucs_low
hdemucs_medium
^^^^^^^^^^^^^^
.. autofunction:: hdemucs_medium
hdemucs_high
^^^^^^^^^^^^
.. autofunction:: hdemucs_high
References References
~~~~~~~~~~ ~~~~~~~~~~
......
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.hdemucs_test_impl import HDemucsTests from torchaudio_unittest.prototype.hdemucs_test_impl import CompareHDemucsOriginal, HDemucsTests
class HDemucsFloat32CPUTest(HDemucsTests, PytorchTestCase): class HDemucsFloat32CPUTest(HDemucsTests, CompareHDemucsOriginal, PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device("cpu") device = torch.device("cpu")
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from torchaudio_unittest.prototype.hdemucs_test_impl import HDemucsTests from torchaudio_unittest.prototype.hdemucs_test_impl import CompareHDemucsOriginal, HDemucsTests
@skipIfNoCuda @skipIfNoCuda
class HDemucsFloat32GPUTest(HDemucsTests, PytorchTestCase): class HDemucsFloat32GPUTest(HDemucsTests, CompareHDemucsOriginal, PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device("cuda") device = torch.device("cuda")
import itertools
from typing import List
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from torchaudio.prototype.models.hdemucs import _HDecLayer, _HEncLayer, HDemucs from torchaudio.prototype.models.hdemucs import _HDecLayer, _HEncLayer, HDemucs, hdemucs_high, hdemucs_low
from torchaudio_unittest.common_utils import TestBaseMixin from torchaudio_unittest.common_utils import skipIfNoModule, TestBaseMixin, TorchaudioTestCase
def _get_hdemucs_model(sources: List[str], n_fft: int = 4096, depth: int = 6, sample_rate: int = 44100):
    """Construct an ``HDemucs`` model for the given sources.

    Args:
        sources (List[str]): Source names, passed to ``HDemucs`` as its first argument.
        n_fft (int, optional): Forwarded to ``HDemucs`` as ``nfft``. (Default: ``4096``)
        depth (int, optional): Forwarded to ``HDemucs`` as ``depth``. (Default: ``6``)
        sample_rate (int, optional): Forwarded to ``HDemucs`` as ``sample_rate``. (Default: ``44100``)

    Returns:
        The newly constructed ``HDemucs`` instance.
    """
    # Note the spelling difference: this helper takes ``n_fft`` while HDemucs expects ``nfft``.
    model_options = {"nfft": n_fft, "depth": depth, "sample_rate": sample_rate}
    return HDemucs(sources, **model_options)
def _get_inputs(sample_rate: int, device: torch.device, batch_size: int = 1, duration: int = 10, channels: int = 2):
    """Create a random float32 tensor to feed to a model under test.

    Args:
        sample_rate (int): Multiplied by ``duration`` to get the size of the last dimension.
        device (torch.device): Device on which the tensor is allocated.
        batch_size (int, optional): Size of the first dimension. (Default: ``1``)
        duration (int, optional): Multiplied by ``sample_rate`` to get the size of the
            last dimension. (Default: ``10``)
        channels (int, optional): Size of the second dimension. (Default: ``2``)

    Returns:
        torch.Tensor: Tensor of shape ``(batch_size, channels, duration * sample_rate)``
        drawn with ``torch.rand``.
    """
    num_frames = duration * sample_rate
    shape = (batch_size, channels, num_frames)
    return torch.rand(shape, dtype=torch.float32, device=device)
def _get_hdemucs_model(sources): SOURCE_OPTIONS = [
return HDemucs(sources) (["bass", "drums", "other", "vocals"],),
(["bass", "drums", "other"],),
(["bass", "vocals"],),
(["vocals"],),
]
SOURCES_OUTPUT_CONFIG = parameterized.expand(SOURCE_OPTIONS)
class HDemucsTests(TestBaseMixin): class HDemucsTests(TestBaseMixin):
def _get_inputs(self, duration: int, channels: int, batch_size: int, sample_rate: int): @parameterized.expand(list(itertools.product(SOURCE_OPTIONS, [(1024, 5), (2048, 6), (4096, 6)])))
sample = torch.rand(batch_size, channels, duration * sample_rate, dtype=torch.float32, device=self.device) def test_hdemucs_output_shape(self, sources, nfft_bundle):
return sample
@parameterized.expand(
[
(["bass", "drums", "other", "vocals"],),
(["bass", "drums", "other"],),
(["bass", "vocals"],),
(["vocals"],),
]
)
def test_hdemucs_output_shape(self, sources):
r"""Feed tensors with specific shape to HDemucs and validate r"""Feed tensors with specific shape to HDemucs and validate
that it outputs with a tensor with expected shape. that it outputs with a tensor with expected shape.
""" """
batch_size = 1
duration = 10 duration = 10
channels = 2 channels = 2
batch_size = 1
sample_rate = 44100 sample_rate = 44100
nfft = nfft_bundle[0]
depth = nfft_bundle[1]
model = _get_hdemucs_model(sources).to(self.device).eval() model = _get_hdemucs_model(sources, nfft, depth).to(self.device).eval()
inputs = self._get_inputs(duration, channels, batch_size, sample_rate) inputs = _get_inputs(sample_rate, self.device, batch_size, duration, channels)
split_sample = model(inputs) split_sample = model(inputs)
...@@ -106,3 +115,46 @@ class HDemucsTests(TestBaseMixin): ...@@ -106,3 +115,46 @@ class HDemucsTests(TestBaseMixin):
assert z.size() == (batch_size, chout, t * stride) assert z.size() == (batch_size, chout, t * stride)
assert y.size() == (batch_size, chin, t) assert y.size() == (batch_size, chin, t)
@skipIfNoModule("demucs")
class CompareHDemucsOriginal(TorchaudioTestCase):
    """Check the HDemucs factory functions against the original ``demucs`` package.

    Each test builds one model through a factory function and one through
    ``demucs.hdemucs.HDemucs`` with the matching ``nfft`` and ``depth``, seeding the
    random generator identically before each construction, and asserts that both
    models return equal outputs for the same input.
    """

    def _get_original_model(self, sources: List[str], nfft: int, depth: int):
        """Instantiate the reference ``HDemucs`` from the ``demucs`` package."""
        # Imported inside the method rather than at module level; presumably so this
        # file can still be imported when ``demucs`` is not installed -- TODO confirm.
        from demucs import hdemucs as demucs_hdemucs

        return demucs_hdemucs.HDemucs(sources, nfft=nfft, depth=depth)

    def _assert_equal_models(self, factory_hdemucs, depth, nfft, sample_rate, sources):
        """Assert that ``factory_hdemucs`` and a freshly built original model agree.

        Args:
            factory_hdemucs: Model produced by one of the factory functions.
            depth: ``depth`` used to build the original model.
            nfft: ``nfft`` used to build the original model.
            sample_rate: Sample rate used to size the random input.
            sources: Source names used to build the original model.
        """
        # Reseed before constructing the reference model; the callers apply the same
        # seed before constructing the factory model.
        torch.random.manual_seed(0)
        reference = self._get_original_model(sources, nfft, depth)
        reference = reference.to(self.device).eval()
        waveform = _get_inputs(sample_rate=sample_rate, device=self.device)
        actual = factory_hdemucs(waveform)
        expected = reference(waveform)
        self.assertEqual(expected, actual)

    @SOURCES_OUTPUT_CONFIG
    def test_import_recreate_low_model(self, sources):
        # hdemucs_low is compared against the original built with nfft=1024, depth=5.
        sample_rate, nfft, depth = 8000, 1024, 5
        torch.random.manual_seed(0)
        model = hdemucs_low(sources, sample_rate=sample_rate)
        model = model.to(self.device).eval()
        self._assert_equal_models(model, depth, nfft, sample_rate, sources)

    @SOURCES_OUTPUT_CONFIG
    def test_import_recreate_high_model(self, sources):
        # hdemucs_high is compared against the original built with nfft=4096, depth=6.
        sample_rate, nfft, depth = 44100, 4096, 6
        torch.random.manual_seed(0)
        model = hdemucs_high(sources, sample_rate=sample_rate)
        model = model.to(self.device).eval()
        self._assert_equal_models(model, depth, nfft, sample_rate, sources)
from .conv_emformer import ConvEmformer from .conv_emformer import ConvEmformer
from .conv_tasnet import conv_tasnet_base from .conv_tasnet import conv_tasnet_base
from .hdemucs import HDemucs from .hdemucs import HDemucs, hdemucs_high, hdemucs_low, hdemucs_medium
from .rnnt import conformer_rnnt_base, conformer_rnnt_model from .rnnt import conformer_rnnt_base, conformer_rnnt_model
__all__ = [ __all__ = [
...@@ -9,4 +9,7 @@ __all__ = [ ...@@ -9,4 +9,7 @@ __all__ = [
"conv_tasnet_base", "conv_tasnet_base",
"ConvEmformer", "ConvEmformer",
"HDemucs", "HDemucs",
"hdemucs_high",
"hdemucs_medium",
"hdemucs_low",
] ]
...@@ -243,9 +243,9 @@ class _HDecLayer(torch.nn.Module): ...@@ -243,9 +243,9 @@ class _HDecLayer(torch.nn.Module):
if self.empty: if self.empty:
self.rewrite = nn.Identity() self.rewrite = nn.Identity()
self.norm1 = nn.Identity() self.norm1 = nn.Identity()
return else:
self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context) self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
self.norm1 = norm_fn(2 * chin) self.norm1 = norm_fn(2 * chin)
def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor], length): def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor], length):
r"""Forward pass for decoding layer. r"""Forward pass for decoding layer.
...@@ -929,3 +929,57 @@ def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int = ...@@ -929,3 +929,57 @@ def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int =
_, length = x.shape _, length = x.shape
other.append(length) other.append(length)
return x.view(other) return x.view(other)
def hdemucs_low(sources: List[str], sample_rate: int) -> HDemucs:
    r"""Builds the low nfft (1024) version of the HDemucs model.

    This version is suitable for lower sample rates: it bundles a valid ``nfft`` and
    ``depth`` pair (1024 and 5) for a model structured for sample rates around 8 kHz.

    Args:
        sources (List[str]): Sources to use for audio split.
        sample_rate (int): Serves as metadata; lower sample rates are recommended.

    Returns:
        HDemucs:
            HDemucs model.
    """
    low_config = {"nfft": 1024, "depth": 5}
    return HDemucs(sources=sources, sample_rate=sample_rate, **low_config)
def hdemucs_medium(sources: List[str], sample_rate: int) -> HDemucs:
    r"""Builds the medium nfft (2048) version of the HDemucs model.

    This version is suitable for medium sample rates: it bundles a valid ``nfft`` and
    ``depth`` pair (2048 and 6) for a model structured for sample rates around 16-32 kHz.

    .. note::

        Medium HDemucs has not been tested against the original Hybrid Demucs, as this
        nfft and depth configuration is not compatible with the original implementation
        in https://github.com/facebookresearch/demucs

    Args:
        sources (List[str]): Sources to use for audio split.
        sample_rate (int): Serves as metadata; middle tier sample rates (16 kHz) are recommended.

    Returns:
        HDemucs:
            HDemucs model.
    """
    medium_config = {"nfft": 2048, "depth": 6}
    return HDemucs(sources=sources, sample_rate=sample_rate, **medium_config)
def hdemucs_high(sources: List[str], sample_rate: int) -> HDemucs:
    r"""Builds the high nfft (4096) version of the HDemucs model.

    This version is suitable for high/standard music sample rates: it bundles a valid
    ``nfft`` and ``depth`` pair (4096 and 6) for a model structured for sample rates
    around 44.1-48 kHz.

    Args:
        sources (List[str]): Sources to use for audio split.
        sample_rate (int): Serves as metadata; higher/standard sample rates
            (44.1 kHz, 48 kHz) are recommended.

    Returns:
        HDemucs:
            HDemucs model.
    """
    high_config = {"nfft": 4096, "depth": 6}
    return HDemucs(sources=sources, sample_rate=sample_rate, **high_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment