Commit 608b8ea6 authored by Sean Kim's avatar Sean Kim Committed by Facebook GitHub Bot
Browse files

Hybrid Demucs model implementation (#2506)

Summary:
Draft PR with initial model implementation with minor changes from previous implementation

Pull Request resolved: https://github.com/pytorch/audio/pull/2506

Reviewed By: nateanl

Differential Revision: D37762671

Pulled By: skim0514

fbshipit-source-id: b7dc0a6ef725d6ae6d76c23c882623f7d339977c
parent e2641452
...@@ -28,6 +28,13 @@ ConvEmformer ...@@ -28,6 +28,13 @@ ConvEmformer
.. automethod:: infer .. automethod:: infer
HDemucs
~~~~~~~
.. autoclass:: HDemucs
.. automethod:: forward
References References
~~~~~~~~~~ ~~~~~~~~~~
......
...@@ -382,3 +382,9 @@ ...@@ -382,3 +382,9 @@
journal={arXiv preprint arXiv:1706.08612}, journal={arXiv preprint arXiv:1706.08612},
year={2017} year={2017}
} }
@inproceedings{defossez2021hybrid,
title={Hybrid Spectrogram and Waveform Source Separation},
author={D{\'e}fossez, Alexandre},
booktitle={Proceedings of the ISMIR 2021 Workshop on Music Source Separation},
year={2021}
}
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.hdemucs_test_impl import HDemucsTests
class HDemucsFloat32CPUTest(HDemucsTests, PytorchTestCase):
dtype = torch.float32
device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from torchaudio_unittest.prototype.hdemucs_test_impl import HDemucsTests
@skipIfNoCuda
class HDemucsFloat32GPUTest(HDemucsTests, PytorchTestCase):
dtype = torch.float32
device = torch.device("cuda")
import torch
from parameterized import parameterized
from torchaudio.prototype.models.hdemucs import _HDecLayer, _HEncLayer, HDemucs
from torchaudio_unittest.common_utils import TestBaseMixin
def _get_hdemucs_model(sources):
return HDemucs(sources)
class HDemucsTests(TestBaseMixin):
def _get_inputs(self, duration: int, channels: int, batch_size: int, sample_rate: int):
sample = torch.rand(batch_size, channels, duration * sample_rate, dtype=torch.float32, device=self.device)
return sample
@parameterized.expand(
[
(["bass", "drums", "other", "vocals"],),
(["bass", "drums", "other"],),
(["bass", "vocals"],),
(["vocals"],),
]
)
def test_hdemucs_output_shape(self, sources):
r"""Feed tensors with specific shape to HDemucs and validate
that it outputs with a tensor with expected shape.
"""
batch_size = 1
duration = 10
channels = 2
sample_rate = 44100
model = _get_hdemucs_model(sources).to(self.device).eval()
inputs = self._get_inputs(duration, channels, batch_size, sample_rate)
split_sample = model(inputs)
assert split_sample.shape == (batch_size, len(sources), channels, duration * sample_rate)
def test_encoder_output_shape_frequency(self):
r"""Feed tensors with specific shape to HDemucs Decoder and validate
that it outputs with a tensor with expected shape for frequency domain.
"""
batch_size = 1
chin, chout = 4, 48
f_bins = 2048
t = 800
stride = 4
model = _HEncLayer(chin, chout).to(self.device).eval()
x = torch.rand(batch_size, chin, f_bins, t, device=self.device, dtype=self.dtype)
out = model(x)
assert out.size() == (batch_size, chout, f_bins / stride, t)
def test_decoder_output_shape_frequency(self):
r"""Feed tensors with specific shape to HDemucs Decoder and validate
that it outputs with a tensor with expected shape for frequency domain.
"""
batch_size = 1
chin, chout = 96, 48
f_bins = 128
t = 800
stride = 4
model = _HDecLayer(chin, chout).to(self.device).eval()
x = torch.rand(batch_size, chin, f_bins, t, device=self.device, dtype=self.dtype)
skip = torch.rand(batch_size, chin, f_bins, t, device=self.device, dtype=self.dtype)
z, y = model(x, skip, t)
assert z.size() == (batch_size, chout, f_bins * stride, t)
assert y.size() == (batch_size, chin, f_bins, t)
def test_encoder_output_shape_time(self):
r"""Feed tensors with specific shape to HDemucs Decoder and validate
that it outputs with a tensor with expected shape for time domain.
"""
batch_size = 1
chin, chout = 4, 48
t = 800
stride = 4
model = _HEncLayer(chin, chout, freq=False).to(self.device).eval()
x = torch.rand(batch_size, chin, t, device=self.device, dtype=self.dtype)
out = model(x)
assert out.size() == (batch_size, chout, t / stride)
def test_decoder_output_shape_time(self):
r"""Feed tensors with specific shape to HDemucs Decoder and validate
that it outputs with a tensor with expected shape for time domain.
"""
batch_size = 1
chin, chout = 96, 48
t = 800
stride = 4
model = _HDecLayer(chin, chout, freq=False).to(self.device).eval()
x = torch.rand(batch_size, chin, t, device=self.device, dtype=self.dtype)
skip = torch.rand(batch_size, chin, t, device=self.device, dtype=self.dtype)
z, y = model(x, skip, t * stride)
assert z.size() == (batch_size, chout, t * stride)
assert y.size() == (batch_size, chin, t)
from .conv_emformer import ConvEmformer from .conv_emformer import ConvEmformer
from .conv_tasnet import conv_tasnet_base from .conv_tasnet import conv_tasnet_base
from .hdemucs import HDemucs
from .rnnt import conformer_rnnt_base, conformer_rnnt_model from .rnnt import conformer_rnnt_base, conformer_rnnt_model
__all__ = [ __all__ = [
...@@ -7,4 +8,5 @@ __all__ = [ ...@@ -7,4 +8,5 @@ __all__ = [
"conformer_rnnt_model", "conformer_rnnt_model",
"conv_tasnet_base", "conv_tasnet_base",
"ConvEmformer", "ConvEmformer",
"HDemucs",
] ]
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment