init_model.py 2.05 KB
Newer Older
Yoach Lacombe's avatar
Yoach Lacombe committed
1
from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
Yoach Lacombe's avatar
Yoach Lacombe committed
2
3
4
from transformers import T5Config, EncodecConfig
from transformers import AutoConfig

Yoach Lacombe's avatar
Yoach Lacombe committed
5

6
from transformers import AutoConfig, AutoModel
Yoach Lacombe's avatar
Yoach Lacombe committed
7
from parler_tts import DACConfig, DACModel
8
9
10
11

# Register the DAC classes imported from parler_tts with the transformers Auto*
# registries: the "dac" model type maps to DACConfig, and DACConfig maps to
# DACModel. Presumably this is what lets the AutoConfig.from_pretrained call on
# the audio checkpoint further down resolve to DACConfig — TODO confirm against
# that checkpoint's config.
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)

Yoach Lacombe's avatar
Yoach Lacombe committed
12
# Checkpoints the untrained model is assembled from.
text_model = "google-t5/t5-small"  # passed on as the text encoder checkpoint
encodec_version = "ylacombe/dac_44khZ_8kbps"  # passed on as the audio encoder checkpoint
num_codebooks = 9  # forwarded to the decoder config

# Load the two configurations, text model first and audio model second.
t5, encodec = (
    AutoConfig.from_pretrained(checkpoint)
    for checkpoint in (text_model, encodec_version)
)

# Codebook size of the audio model; the decoder vocabulary and the special
# token ids are all derived from this value.
encodec_vocab_size = encodec.codebook_size


Yoach Lacombe's avatar
Yoach Lacombe committed
23
# Configuration of the decoder. Keyword arguments are grouped by concern.
decoder_config = ParlerTTSDecoderConfig(
    # Vocabulary and special tokens, all derived from the codec's codebook size.
    # pad and eos share the id `codebook_size`; bos is `codebook_size + 1`.
    # NOTE(review): bos_token_id equals vocab_size, i.e. one past the last id a
    # table of vocab_size entries would hold — presumably the model reserves an
    # extra embedding row for it; confirm against the decoder implementation.
    vocab_size=encodec_vocab_size + 1,
    pad_token_id=encodec_vocab_size,
    eos_token_id=encodec_vocab_size,
    bos_token_id=encodec_vocab_size + 1,
    num_codebooks=num_codebooks,
    # Transformer dimensions.
    hidden_size=1024,
    ffn_dim=4096,
    num_hidden_layers=12,
    num_attention_heads=16,
    activation_function="gelu",
    max_position_embeddings=3000,  # 30 s = 2580
    # Every dropout variant is switched off.
    dropout=0.0,
    attention_dropout=0.0,
    activation_dropout=0.0,
    layerdrop=0.0,
    # Runtime behaviour.
    use_cache=True,
)
Yoach Lacombe's avatar
Yoach Lacombe committed
41

Yoach Lacombe's avatar
Yoach Lacombe committed
42

Yoach Lacombe's avatar
Yoach Lacombe committed
43
# Directory the decoder is written to and then read back from when the full
# model is assembled; named once so the two uses cannot drift apart.
decoder_checkpoint_dir = "/raid/yoach/tmp/artefacts/decoder/"

# Build the decoder from its configuration alone and store it on disk.
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained(decoder_checkpoint_dir)

# Assemble the complete model from the text encoder checkpoint, the audio
# encoder checkpoint and the decoder directory saved just above.
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
    text_encoder_pretrained_model_name_or_path=text_model,
    audio_encoder_pretrained_model_name_or_path=encodec_version,
    decoder_pretrained_model_name_or_path=decoder_checkpoint_dir,
    vocab_size=t5.vocab_size,
)

# All generation defaults are written to the model's own generation config.
generation_config = model.generation_config

# Special tokens: generation starts from `codebook_size + 1`, while padding and
# end-of-sequence share `codebook_size` — the same ids given to the decoder config.
generation_config.decoder_start_token_id = encodec_vocab_size + 1
generation_config.pad_token_id = generation_config.eos_token_id = encodec_vocab_size

# Cap the generated length at 30 times the audio encoder's frame rate
# (presumably 30 seconds of audio — confirm the unit of frame_rate).
frame_rate = model.audio_encoder.config.frame_rate
generation_config.max_length = int(30 * frame_rate)

# Decoding defaults; the trailing notes keep the alternatives the author had
# written next to each value.
generation_config.do_sample = False  # alternative: True
generation_config.guidance_scale = 1  # alternative: 3.0

# Store the assembled model.
model.save_pretrained("/raid/yoach/tmp/artefacts/small-stable-speech-untrained/")