Commit 30691a32 authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

further cleaning

parent 31a54850
......@@ -7,3 +7,7 @@ from .modeling_parler_tts import (
)
from .dac_wrapper import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
\ No newline at end of file
......@@ -45,6 +45,11 @@ from transformers.utils import (
)
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .dac_wrapper import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
......
......@@ -14,17 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Train a text-to-speech model using 🤗 Transformers Seq2SeqTrainer"""
""" Train Parler-TTS using 🤗 Accelerate"""
import functools
import json
import logging
import os
import re
import sys
import shutil
import warnings
import math
import time
from multiprocess import set_start_method
from datetime import timedelta
......@@ -48,26 +44,18 @@ import transformers
from transformers import (
AutoFeatureExtractor,
AutoModel,
AutoModelWithLMHead,
AutoProcessor,
AutoTokenizer,
HfArgumentParser,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
)
from transformers.trainer_utils import is_main_process
from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers import pipeline
from transformers.optimization import get_scheduler
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from transformers.integrations import is_wandb_available
from transformers import AutoConfig, AutoModel
from parler_tts import DACConfig, DACModel
from transformers.modeling_outputs import BaseModelOutput
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
from transformers import AutoModel
from accelerate import Accelerator
......@@ -77,7 +65,6 @@ from accelerate.utils.memory import release_memory
from parler_tts import (
ParlerTTSForConditionalGeneration,
ParlerTTSConfig,
apply_delay_pattern_mask,
build_delay_pattern_mask,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment