Commit ec0e3a80 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Move hybrid demucs model out of prototype (#2668)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2668

Reviewed By: nateanl, mthrok

Differential Revision: D39433671

Pulled By: carolineechen

fbshipit-source-id: 3545a5b4019832861c34fd8c05e5f8600fd80d5c
parent 1ddb70f9
...@@ -41,6 +41,37 @@ Emformer ...@@ -41,6 +41,37 @@ Emformer
.. automethod:: infer .. automethod:: infer
Hybrid Demucs
~~~~~~~~~~~~~
Model
-----
HDemucs
^^^^^^^
.. autoclass:: HDemucs
.. automethod:: forward
Factory Functions
-----------------
hdemucs_low
^^^^^^^^^^^
.. autofunction:: hdemucs_low
hdemucs_medium
^^^^^^^^^^^^^^
.. autofunction:: hdemucs_medium
hdemucs_high
^^^^^^^^^^^^
.. autofunction:: hdemucs_high
RNN-T RNN-T
~~~~~ ~~~~~
......
...@@ -28,37 +28,6 @@ ConvEmformer ...@@ -28,37 +28,6 @@ ConvEmformer
.. automethod:: infer .. automethod:: infer
Hybrid Demucs
~~~~~~~~~~~~~
Model
-----
HDemucs
^^^^^^^
.. autoclass:: HDemucs
.. automethod:: forward
Factory Functions
-----------------
hdemucs_low
^^^^^^^^^^^
.. autofunction:: hdemucs_low
hdemucs_medium
^^^^^^^^^^^^^^
.. autofunction:: hdemucs_medium
hdemucs_high
^^^^^^^^^^^^
.. autofunction:: hdemucs_high
References References
~~~~~~~~~~ ~~~~~~~~~~
......
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.hdemucs_test_impl import CompareHDemucsOriginal, HDemucsTests from torchaudio_unittest.models.hdemucs.hdemucs_test_impl import CompareHDemucsOriginal, HDemucsTests
class HDemucsFloat32CPUTest(HDemucsTests, CompareHDemucsOriginal, PytorchTestCase): class HDemucsFloat32CPUTest(HDemucsTests, CompareHDemucsOriginal, PytorchTestCase):
......
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from torchaudio_unittest.prototype.hdemucs_test_impl import CompareHDemucsOriginal, HDemucsTests from torchaudio_unittest.models.hdemucs.hdemucs_test_impl import CompareHDemucsOriginal, HDemucsTests
@skipIfNoCuda @skipIfNoCuda
......
...@@ -3,7 +3,7 @@ from typing import List ...@@ -3,7 +3,7 @@ from typing import List
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from torchaudio.prototype.models.hdemucs import _HDecLayer, _HEncLayer, HDemucs, hdemucs_high, hdemucs_low from torchaudio.models._hdemucs import _HDecLayer, _HEncLayer, HDemucs, hdemucs_high, hdemucs_low
from torchaudio_unittest.common_utils import skipIfNoModule, TestBaseMixin, TorchaudioTestCase from torchaudio_unittest.common_utils import skipIfNoModule, TestBaseMixin, TorchaudioTestCase
......
from ._hdemucs import HDemucs, hdemucs_high, hdemucs_low, hdemucs_medium
from .conformer import Conformer from .conformer import Conformer
from .conv_tasnet import ConvTasNet from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech from .deepspeech import DeepSpeech
...@@ -50,4 +51,8 @@ __all__ = [ ...@@ -50,4 +51,8 @@ __all__ = [
"RNNTBeamSearch", "RNNTBeamSearch",
"emformer_rnnt_base", "emformer_rnnt_base",
"emformer_rnnt_model", "emformer_rnnt_model",
"HDemucs",
"hdemucs_low",
"hdemucs_medium",
"hdemucs_high",
] ]
from .conv_emformer import ConvEmformer from .conv_emformer import ConvEmformer
from .conv_tasnet import conv_tasnet_base from .conv_tasnet import conv_tasnet_base
from .hdemucs import HDemucs, hdemucs_high, hdemucs_low, hdemucs_medium
from .rnnt import conformer_rnnt_base, conformer_rnnt_model from .rnnt import conformer_rnnt_base, conformer_rnnt_model
__all__ = [ __all__ = [
...@@ -8,8 +7,4 @@ __all__ = [ ...@@ -8,8 +7,4 @@ __all__ = [
"conformer_rnnt_model", "conformer_rnnt_model",
"conv_tasnet_base", "conv_tasnet_base",
"ConvEmformer", "ConvEmformer",
"HDemucs",
"hdemucs_high",
"hdemucs_medium",
"hdemucs_low",
] ]
import torchaudio
functions = ["HDemucs", "hdemucs_high", "hdemucs_medium", "hdemucs_low"]
def __getattr__(name: str):
if name in functions:
import warnings
warnings.warn(
f"{__name__}.{name} has been moved to torchaudio.models.hdemucs",
DeprecationWarning,
)
return getattr(torchaudio.models, name)
raise AttributeError(f"module {__name__} has no attribute {name}")
def __dir__():
return functions
...@@ -5,7 +5,8 @@ from typing import Callable ...@@ -5,7 +5,8 @@ from typing import Callable
import torch import torch
import torchaudio import torchaudio
from torchaudio.prototype.models import conv_tasnet_base, hdemucs_high from torchaudio.models import hdemucs_high
from torchaudio.prototype.models import conv_tasnet_base
@dataclass @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment