Commit ec0e3a80 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Move hybrid demucs model out of prototype (#2668)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2668

Reviewed By: nateanl, mthrok

Differential Revision: D39433671

Pulled By: carolineechen

fbshipit-source-id: 3545a5b4019832861c34fd8c05e5f8600fd80d5c
parent 1ddb70f9
......@@ -41,6 +41,37 @@ Emformer
.. automethod:: infer
Hybrid Demucs
~~~~~~~~~~~~~
Model
-----
HDemucs
^^^^^^^
.. autoclass:: HDemucs
.. automethod:: forward
Factory Functions
-----------------
hdemucs_low
^^^^^^^^^^^
.. autofunction:: hdemucs_low
hdemucs_medium
^^^^^^^^^^^^^^
.. autofunction:: hdemucs_medium
hdemucs_high
^^^^^^^^^^^^
.. autofunction:: hdemucs_high
RNN-T
~~~~~
......
......@@ -28,37 +28,6 @@ ConvEmformer
.. automethod:: infer
Hybrid Demucs
~~~~~~~~~~~~~
Model
-----
HDemucs
^^^^^^^
.. autoclass:: HDemucs
.. automethod:: forward
Factory Functions
-----------------
hdemucs_low
^^^^^^^^^^^
.. autofunction:: hdemucs_low
hdemucs_medium
^^^^^^^^^^^^^^
.. autofunction:: hdemucs_medium
hdemucs_high
^^^^^^^^^^^^
.. autofunction:: hdemucs_high
References
~~~~~~~~~~
......
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.hdemucs_test_impl import CompareHDemucsOriginal, HDemucsTests
from torchaudio_unittest.models.hdemucs.hdemucs_test_impl import CompareHDemucsOriginal, HDemucsTests
class HDemucsFloat32CPUTest(HDemucsTests, CompareHDemucsOriginal, PytorchTestCase):
......
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from torchaudio_unittest.prototype.hdemucs_test_impl import CompareHDemucsOriginal, HDemucsTests
from torchaudio_unittest.models.hdemucs.hdemucs_test_impl import CompareHDemucsOriginal, HDemucsTests
@skipIfNoCuda
......
......@@ -3,7 +3,7 @@ from typing import List
import torch
from parameterized import parameterized
from torchaudio.prototype.models.hdemucs import _HDecLayer, _HEncLayer, HDemucs, hdemucs_high, hdemucs_low
from torchaudio.models._hdemucs import _HDecLayer, _HEncLayer, HDemucs, hdemucs_high, hdemucs_low
from torchaudio_unittest.common_utils import skipIfNoModule, TestBaseMixin, TorchaudioTestCase
......
from ._hdemucs import HDemucs, hdemucs_high, hdemucs_low, hdemucs_medium
from .conformer import Conformer
from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech
......@@ -50,4 +51,8 @@ __all__ = [
"RNNTBeamSearch",
"emformer_rnnt_base",
"emformer_rnnt_model",
"HDemucs",
"hdemucs_low",
"hdemucs_medium",
"hdemucs_high",
]
from .conv_emformer import ConvEmformer
from .conv_tasnet import conv_tasnet_base
from .hdemucs import HDemucs, hdemucs_high, hdemucs_low, hdemucs_medium
from .rnnt import conformer_rnnt_base, conformer_rnnt_model
__all__ = [
......@@ -8,8 +7,4 @@ __all__ = [
"conformer_rnnt_model",
"conv_tasnet_base",
"ConvEmformer",
"HDemucs",
"hdemucs_high",
"hdemucs_medium",
"hdemucs_low",
]
import torchaudio
functions = ["HDemucs", "hdemucs_high", "hdemucs_medium", "hdemucs_low"]
def __getattr__(name: str):
if name in functions:
import warnings
warnings.warn(
f"{__name__}.{name} has been moved to torchaudio.models.hdemucs",
DeprecationWarning,
)
return getattr(torchaudio.models, name)
raise AttributeError(f"module {__name__} has no attribute {name}")
def __dir__():
return functions
......@@ -5,7 +5,8 @@ from typing import Callable
import torch
import torchaudio
from torchaudio.prototype.models import conv_tasnet_base, hdemucs_high
from torchaudio.models import hdemucs_high
from torchaudio.prototype.models import conv_tasnet_base
@dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment