Commit 30691a32 authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

further cleaning

parent 31a54850
...@@ -7,3 +7,7 @@ from .modeling_parler_tts import ( ...@@ -7,3 +7,7 @@ from .modeling_parler_tts import (
) )
from .dac_wrapper import DACConfig, DACModel from .dac_wrapper import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
\ No newline at end of file
...@@ -45,6 +45,11 @@ from transformers.utils import ( ...@@ -45,6 +45,11 @@ from transformers.utils import (
) )
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .dac_wrapper import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer from transformers.generation.streamers import BaseStreamer
......
...@@ -14,17 +14,13 @@ ...@@ -14,17 +14,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Train a text-to-speech model using 🤗 Transformers Seq2SeqTrainer""" """ Train Parler-TTS using 🤗 Accelerate"""
import functools
import json
import logging import logging
import os import os
import re import re
import sys import sys
import shutil import shutil
import warnings
import math
import time import time
from multiprocess import set_start_method from multiprocess import set_start_method
from datetime import timedelta from datetime import timedelta
...@@ -48,26 +44,18 @@ import transformers ...@@ -48,26 +44,18 @@ import transformers
from transformers import ( from transformers import (
AutoFeatureExtractor, AutoFeatureExtractor,
AutoModel, AutoModel,
AutoModelWithLMHead,
AutoProcessor, AutoProcessor,
AutoTokenizer, AutoTokenizer,
HfArgumentParser, HfArgumentParser,
Seq2SeqTrainer,
Seq2SeqTrainingArguments, Seq2SeqTrainingArguments,
) )
from transformers.trainer_utils import is_main_process
from transformers.trainer_pt_utils import LengthGroupedSampler from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers import pipeline from transformers import pipeline
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
from transformers.utils import check_min_version, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from transformers.integrations import is_wandb_available from transformers.integrations import is_wandb_available
from transformers import AutoConfig, AutoModel from transformers import AutoModel
from parler_tts import DACConfig, DACModel
from transformers.modeling_outputs import BaseModelOutput
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
from accelerate import Accelerator from accelerate import Accelerator
...@@ -77,7 +65,6 @@ from accelerate.utils.memory import release_memory ...@@ -77,7 +65,6 @@ from accelerate.utils.memory import release_memory
from parler_tts import ( from parler_tts import (
ParlerTTSForConditionalGeneration, ParlerTTSForConditionalGeneration,
ParlerTTSConfig, ParlerTTSConfig,
apply_delay_pattern_mask,
build_delay_pattern_mask, build_delay_pattern_mask,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment