"vscode:/vscode.git/clone" did not exist on "9bca40296e3f00fb26597a0f4cfe2fdfd2ad2fd2"
Commit 338e3104 authored by Sean Kim, committed by Facebook GitHub Bot
Browse files

Change docstring for easier understanding (#2570)

Summary:
Edit the factory functions' docstrings.

Pull Request resolved: https://github.com/pytorch/audio/pull/2570

Reviewed By: carolineechen

Differential Revision: D38250369

Pulled By: skim0514

fbshipit-source-id: fa777e37d7cc517cf4ff1842d5585bf36558f50a
parent 39b6343d
...@@ -7,8 +7,8 @@ from torchaudio.prototype.models.hdemucs import _HDecLayer, _HEncLayer, HDemucs, ...@@ -7,8 +7,8 @@ from torchaudio.prototype.models.hdemucs import _HDecLayer, _HEncLayer, HDemucs,
from torchaudio_unittest.common_utils import skipIfNoModule, TestBaseMixin, TorchaudioTestCase from torchaudio_unittest.common_utils import skipIfNoModule, TestBaseMixin, TorchaudioTestCase
def _get_hdemucs_model(sources: List[str], n_fft: int = 4096, depth: int = 6, sample_rate: int = 44100): def _get_hdemucs_model(sources: List[str], n_fft: int = 4096, depth: int = 6):
return HDemucs(sources, nfft=n_fft, depth=depth, sample_rate=sample_rate) return HDemucs(sources, nfft=n_fft, depth=depth)
def _get_inputs(sample_rate: int, device: torch.device, batch_size: int = 1, duration: int = 10, channels: int = 2): def _get_inputs(sample_rate: int, device: torch.device, batch_size: int = 1, duration: int = 10, channels: int = 2):
...@@ -146,7 +146,7 @@ class CompareHDemucsOriginal(TorchaudioTestCase): ...@@ -146,7 +146,7 @@ class CompareHDemucsOriginal(TorchaudioTestCase):
depth = 5 depth = 5
torch.random.manual_seed(0) torch.random.manual_seed(0)
factory_hdemucs = hdemucs_low(sources, sample_rate=sample_rate).to(self.device).eval() factory_hdemucs = hdemucs_low(sources).to(self.device).eval()
self._assert_equal_models(factory_hdemucs, depth, nfft, sample_rate, sources) self._assert_equal_models(factory_hdemucs, depth, nfft, sample_rate, sources)
@SOURCES_OUTPUT_CONFIG @SOURCES_OUTPUT_CONFIG
...@@ -156,5 +156,5 @@ class CompareHDemucsOriginal(TorchaudioTestCase): ...@@ -156,5 +156,5 @@ class CompareHDemucsOriginal(TorchaudioTestCase):
depth = 6 depth = 6
torch.random.manual_seed(0) torch.random.manual_seed(0)
factory_hdemucs = hdemucs_high(sources, sample_rate=sample_rate).to(self.device).eval() factory_hdemucs = hdemucs_high(sources).to(self.device).eval()
self._assert_equal_models(factory_hdemucs, depth, nfft, sample_rate, sources) self._assert_equal_models(factory_hdemucs, depth, nfft, sample_rate, sources)
...@@ -303,7 +303,8 @@ class HDemucs(torch.nn.Module): ...@@ -303,7 +303,8 @@ class HDemucs(torch.nn.Module):
Hybrid Demucs model from *Hybrid Spectrogram and Waveform Source Separation* [:footcite:`defossez2021hybrid`]. Hybrid Demucs model from *Hybrid Spectrogram and Waveform Source Separation* [:footcite:`defossez2021hybrid`].
Args: Args:
sources (List[str]): list of source names. sources (List[str]): list of source names. List can contain the following source
options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
audio_channels (int, optional): input/output audio channels. (Default: 2) audio_channels (int, optional): input/output audio channels. (Default: 2)
channels (int, optional): initial number of hidden channels. (Default: 48) channels (int, optional): initial number of hidden channels. (Default: 48)
growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2) growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2)
...@@ -328,8 +329,6 @@ class HDemucs(torch.nn.Module): ...@@ -328,8 +329,6 @@ class HDemucs(torch.nn.Module):
dconv_attn (int, optional): adds attention layers in DConv branch starting at this layer. (Default: 4) dconv_attn (int, optional): adds attention layers in DConv branch starting at this layer. (Default: 4)
dconv_lstm (int, optional): adds a LSTM layer in DConv branch starting at this layer. (Default: 4) dconv_lstm (int, optional): adds a LSTM layer in DConv branch starting at this layer. (Default: 4)
dconv_init (float, optional): initial scale for the DConv branch LayerScale. (Default: 1e-4) dconv_init (float, optional): initial scale for the DConv branch LayerScale. (Default: 1e-4)
sample_rate (int, optional): sample rate, serving as metadata not actually used (Default: 44100)
segment (int, optional): segment size (Default: 40)
""" """
def __init__( def __init__(
...@@ -355,8 +354,6 @@ class HDemucs(torch.nn.Module): ...@@ -355,8 +354,6 @@ class HDemucs(torch.nn.Module):
dconv_attn: int = 4, dconv_attn: int = 4,
dconv_lstm: int = 4, dconv_lstm: int = 4,
dconv_init: float = 1e-4, dconv_init: float = 1e-4,
sample_rate: int = 44100,
segment: int = 4 * 10,
): ):
super().__init__() super().__init__()
self.depth = depth self.depth = depth
...@@ -367,8 +364,6 @@ class HDemucs(torch.nn.Module): ...@@ -367,8 +364,6 @@ class HDemucs(torch.nn.Module):
self.context = context self.context = context
self.stride = stride self.stride = stride
self.channels = channels self.channels = channels
self.sample_rate = sample_rate
self.segment = segment
self.hop_length = self.nfft // 4 self.hop_length = self.nfft // 4
self.freq_emb = None self.freq_emb = None
...@@ -480,7 +475,7 @@ class HDemucs(torch.nn.Module): ...@@ -480,7 +475,7 @@ class HDemucs(torch.nn.Module):
raise ValueError("Hop length must be nfft // 4") raise ValueError("Hop length must be nfft // 4")
le = int(math.ceil(x.shape[-1] / hl)) le = int(math.ceil(x.shape[-1] / hl))
pad = hl // 2 * 3 pad = hl // 2 * 3
x = F.pad(x, [pad, pad + le * hl - x.shape[-1]], mode="reflect") x = self._pad1d(x, pad, pad + le * hl - x.shape[-1], mode="reflect")
z = _spectro(x, nfft, hl)[..., :-1, :] z = _spectro(x, nfft, hl)[..., :-1, :]
if z.shape[-1] != le + 4: if z.shape[-1] != le + 4:
...@@ -498,6 +493,16 @@ class HDemucs(torch.nn.Module): ...@@ -498,6 +493,16 @@ class HDemucs(torch.nn.Module):
x = x[..., pad : pad + length] x = x[..., pad : pad + length]
return x return x
def _pad1d(self, x: torch.Tensor, padding_left: int, padding_right: int, mode: str = "zero", value: float = 0.0):
"""Wrapper around F.pad, in order for reflect padding when num_frames is shorter than max_pad.
Add extra zero padding around in order for padding to not break."""
length = x.shape[-1]
if mode == "reflect":
max_pad = max(padding_left, padding_right)
if length <= max_pad:
x = F.pad(x, (0, max_pad - length + 1))
return F.pad(x, (padding_left, padding_right), mode, value)
def _magnitude(self, z): def _magnitude(self, z):
# move the complex dimension to the channel one. # move the complex dimension to the channel one.
B, C, Fr, T = z.shape B, C, Fr, T = z.shape
...@@ -953,23 +958,22 @@ def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int = ...@@ -953,23 +958,22 @@ def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int =
return x.view(other) return x.view(other)
def hdemucs_low(sources: List[str], sample_rate: int) -> HDemucs: def hdemucs_low(sources: List[str]) -> HDemucs:
r"""Builds low nfft (1024) version of HDemucs model. This version is suitable for lower sample rates, and bundles r"""Builds low nfft (1024) version of HDemucs model. This version is suitable for lower sample rates, and bundles
parameters together to call valid nfft and depth values for a model structured for sample rates around 8 kHZ. parameters together to call valid nfft and depth values for a model structured for sample rates around 8 kHZ.
Args: Args:
sources (List[str]): Sources to use for audio split sources (List[str]): See :py:func:`HDemucs`.
sample_rate (int): Serves as metadata, recommend lower sample rates.
Returns: Returns:
HDemucs: HDemucs:
HDemucs model. HDemucs model.
""" """
return HDemucs(sources=sources, nfft=1024, depth=5, sample_rate=sample_rate) return HDemucs(sources=sources, nfft=1024, depth=5)
def hdemucs_medium(sources: List[str], sample_rate: int) -> HDemucs: def hdemucs_medium(sources: List[str]) -> HDemucs:
r"""Builds medium nfft (2048) version of HDemucs model. This version is suitable for medium sample rates,and bundles r"""Builds medium nfft (2048) version of HDemucs model. This version is suitable for medium sample rates,and bundles
parameters together to call valid nfft and depth values for a model structured for sample rates around 16-32 kHZ parameters together to call valid nfft and depth values for a model structured for sample rates around 16-32 kHZ
...@@ -979,29 +983,27 @@ def hdemucs_medium(sources: List[str], sample_rate: int) -> HDemucs: ...@@ -979,29 +983,27 @@ def hdemucs_medium(sources: List[str], sample_rate: int) -> HDemucs:
not compatible with the original implementation in https://github.com/facebookresearch/demucs not compatible with the original implementation in https://github.com/facebookresearch/demucs
Args: Args:
sources (List[str]): Sources to use for audio split sources (List[str]): See :py:func:`HDemucs`.
sample_rate (int): Serves as metadata, recommend middle tier sample rates (16kHz).
Returns: Returns:
HDemucs: HDemucs:
HDemucs model. HDemucs model.
""" """
return HDemucs(sources=sources, nfft=2048, depth=6, sample_rate=sample_rate) return HDemucs(sources=sources, nfft=2048, depth=6)
def hdemucs_high(sources: List[str], sample_rate: int) -> HDemucs: def hdemucs_high(sources: List[str]) -> HDemucs:
r"""Builds high nfft (4096) version of HDemucs model. This version is suitable for high/standard music sample rates, r"""Builds high nfft (4096) version of HDemucs model. This version is suitable for high/standard music sample rates,
and bundles parameters together to call valid nfft and depth values for a model structured for sample rates around and bundles parameters together to call valid nfft and depth values for a model structured for sample rates around
44.1-48 kHZ 44.1-48 kHZ
Args: Args:
sources (List[str]): Sources to use for audio split sources (List[str]): See :py:func:`HDemucs`.
sample_rate (int): Serves as metadata, recommend higher/standard sample rates (44.1kHz, 48kHz).
Returns: Returns:
HDemucs: HDemucs:
HDemucs model. HDemucs model.
""" """
return HDemucs(sources=sources, nfft=4096, depth=6, sample_rate=sample_rate) return HDemucs(sources=sources, nfft=4096, depth=6)
...@@ -78,7 +78,7 @@ CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained *ConvTasNet* [:footcite:`Luo_ ...@@ -78,7 +78,7 @@ CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained *ConvTasNet* [:footcite:`Luo_
HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle( HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle(
_model_path="models/hdemucs_high_trained.pt", _model_path="models/hdemucs_high_trained.pt",
_model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"], sample_rate=44100), _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
_sample_rate=44100, _sample_rate=44100,
) )
HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained *Hybrid Demucs* [:footcite:`defossez2021hybrid`] pipeline for music HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained *Hybrid Demucs* [:footcite:`defossez2021hybrid`] pipeline for music
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment