Commit 338e3104 authored by Sean Kim, committed by Facebook GitHub Bot
Browse files

Change docstring for easier understanding (#2570)

Summary:
Edit factory function's docstrings.

Pull Request resolved: https://github.com/pytorch/audio/pull/2570

Reviewed By: carolineechen

Differential Revision: D38250369

Pulled By: skim0514

fbshipit-source-id: fa777e37d7cc517cf4ff1842d5585bf36558f50a
parent 39b6343d
......@@ -7,8 +7,8 @@ from torchaudio.prototype.models.hdemucs import _HDecLayer, _HEncLayer, HDemucs,
from torchaudio_unittest.common_utils import skipIfNoModule, TestBaseMixin, TorchaudioTestCase
def _get_hdemucs_model(sources: List[str], n_fft: int = 4096, depth: int = 6, sample_rate: int = 44100):
return HDemucs(sources, nfft=n_fft, depth=depth, sample_rate=sample_rate)
def _get_hdemucs_model(sources: List[str], n_fft: int = 4096, depth: int = 6):
return HDemucs(sources, nfft=n_fft, depth=depth)
def _get_inputs(sample_rate: int, device: torch.device, batch_size: int = 1, duration: int = 10, channels: int = 2):
......@@ -146,7 +146,7 @@ class CompareHDemucsOriginal(TorchaudioTestCase):
depth = 5
torch.random.manual_seed(0)
factory_hdemucs = hdemucs_low(sources, sample_rate=sample_rate).to(self.device).eval()
factory_hdemucs = hdemucs_low(sources).to(self.device).eval()
self._assert_equal_models(factory_hdemucs, depth, nfft, sample_rate, sources)
@SOURCES_OUTPUT_CONFIG
......@@ -156,5 +156,5 @@ class CompareHDemucsOriginal(TorchaudioTestCase):
depth = 6
torch.random.manual_seed(0)
factory_hdemucs = hdemucs_high(sources, sample_rate=sample_rate).to(self.device).eval()
factory_hdemucs = hdemucs_high(sources).to(self.device).eval()
self._assert_equal_models(factory_hdemucs, depth, nfft, sample_rate, sources)
......@@ -303,7 +303,8 @@ class HDemucs(torch.nn.Module):
Hybrid Demucs model from *Hybrid Spectrogram and Waveform Source Separation* [:footcite:`defossez2021hybrid`].
Args:
sources (List[str]): list of source names.
sources (List[str]): list of source names. List can contain the following source
options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
audio_channels (int, optional): input/output audio channels. (Default: 2)
channels (int, optional): initial number of hidden channels. (Default: 48)
growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2)
......@@ -328,8 +329,6 @@ class HDemucs(torch.nn.Module):
dconv_attn (int, optional): adds attention layers in DConv branch starting at this layer. (Default: 4)
dconv_lstm (int, optional): adds a LSTM layer in DConv branch starting at this layer. (Default: 4)
dconv_init (float, optional): initial scale for the DConv branch LayerScale. (Default: 1e-4)
sample_rate (int, optional): sample rate, serving as metadata not actually used (Default: 44100)
segment (int, optional): segment size (Default: 40)
"""
def __init__(
......@@ -355,8 +354,6 @@ class HDemucs(torch.nn.Module):
dconv_attn: int = 4,
dconv_lstm: int = 4,
dconv_init: float = 1e-4,
sample_rate: int = 44100,
segment: int = 4 * 10,
):
super().__init__()
self.depth = depth
......@@ -367,8 +364,6 @@ class HDemucs(torch.nn.Module):
self.context = context
self.stride = stride
self.channels = channels
self.sample_rate = sample_rate
self.segment = segment
self.hop_length = self.nfft // 4
self.freq_emb = None
......@@ -480,7 +475,7 @@ class HDemucs(torch.nn.Module):
raise ValueError("Hop length must be nfft // 4")
le = int(math.ceil(x.shape[-1] / hl))
pad = hl // 2 * 3
x = F.pad(x, [pad, pad + le * hl - x.shape[-1]], mode="reflect")
x = self._pad1d(x, pad, pad + le * hl - x.shape[-1], mode="reflect")
z = _spectro(x, nfft, hl)[..., :-1, :]
if z.shape[-1] != le + 4:
......@@ -498,6 +493,16 @@ class HDemucs(torch.nn.Module):
x = x[..., pad : pad + length]
return x
def _pad1d(
    self, x: torch.Tensor, padding_left: int, padding_right: int, mode: str = "zero", value: float = 0.0
) -> torch.Tensor:
    """Pad the last dimension of ``x``; a thin wrapper around ``F.pad``.

    In ``"reflect"`` mode, when the last dimension of ``x`` is not longer than the larger of the
    two paddings, zeros are first appended on the right until it is one sample longer than that
    padding, so that the reflection does not break. In that case the result is longer than
    ``length + padding_left + padding_right`` by the number of zeros appended.

    Args:
        x (torch.Tensor): tensor to pad along its last dimension.
        padding_left (int): number of samples added before the first sample.
        padding_right (int): number of samples added after the last sample.
        mode (str, optional): padding mode handed to ``F.pad``. ``"zero"`` is accepted as an
            alias of ``"constant"``. (Default: ``"zero"``)
        value (float, optional): fill value forwarded to ``F.pad``. (Default: 0.0)

    Returns:
        torch.Tensor: the padded tensor.
    """
    if mode == "zero":
        # ``F.pad`` has no "zero" mode; without this mapping the default mode would raise.
        mode = "constant"

    length = x.shape[-1]
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        if length <= max_pad:
            # Grow the signal to ``max_pad + 1`` samples so the reflection below is legal.
            x = F.pad(x, (0, max_pad - length + 1))
    return F.pad(x, (padding_left, padding_right), mode, value)
def _magnitude(self, z):
# move the complex dimension to the channel one.
B, C, Fr, T = z.shape
......@@ -953,23 +958,22 @@ def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int =
return x.view(other)
def hdemucs_low(sources: List[str], sample_rate: int) -> HDemucs:
def hdemucs_low(sources: List[str]) -> HDemucs:
r"""Builds low nfft (1024) version of HDemucs model. This version is suitable for lower sample rates, and bundles
parameters together to call valid nfft and depth values for a model structured for sample rates around 8 kHz.
Args:
sources (List[str]): Sources to use for audio split
sample_rate (int): Serves as metadata, recommend lower sample rates.
sources (List[str]): See :py:class:`HDemucs`.
Returns:
HDemucs:
HDemucs model.
"""
return HDemucs(sources=sources, nfft=1024, depth=5, sample_rate=sample_rate)
return HDemucs(sources=sources, nfft=1024, depth=5)
def hdemucs_medium(sources: List[str], sample_rate: int) -> HDemucs:
def hdemucs_medium(sources: List[str]) -> HDemucs:
r"""Builds medium nfft (2048) version of HDemucs model. This version is suitable for medium sample rates, and bundles
parameters together to call valid nfft and depth values for a model structured for sample rates around 16-32 kHz.
......@@ -979,29 +983,27 @@ def hdemucs_medium(sources: List[str], sample_rate: int) -> HDemucs:
not compatible with the original implementation in https://github.com/facebookresearch/demucs
Args:
sources (List[str]): Sources to use for audio split
sample_rate (int): Serves as metadata, recommend middle tier sample rates (16kHz).
sources (List[str]): See :py:class:`HDemucs`.
Returns:
HDemucs:
HDemucs model.
"""
return HDemucs(sources=sources, nfft=2048, depth=6, sample_rate=sample_rate)
return HDemucs(sources=sources, nfft=2048, depth=6)
def hdemucs_high(sources: List[str], sample_rate: int) -> HDemucs:
def hdemucs_high(sources: List[str]) -> HDemucs:
r"""Builds high nfft (4096) version of HDemucs model. This version is suitable for high/standard music sample rates,
and bundles parameters together to call valid nfft and depth values for a model structured for sample rates around
44.1-48 kHz.
Args:
sources (List[str]): Sources to use for audio split
sample_rate (int): Serves as metadata, recommend higher/standard sample rates (44.1kHz, 48kHz).
sources (List[str]): See :py:class:`HDemucs`.
Returns:
HDemucs:
HDemucs model.
"""
return HDemucs(sources=sources, nfft=4096, depth=6, sample_rate=sample_rate)
return HDemucs(sources=sources, nfft=4096, depth=6)
......@@ -78,7 +78,7 @@ CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained *ConvTasNet* [:footcite:`Luo_
HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle(
_model_path="models/hdemucs_high_trained.pt",
_model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"], sample_rate=44100),
_model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
_sample_rate=44100,
)
HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained *Hybrid Demucs* [:footcite:`defossez2021hybrid`] pipeline for music
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment