"""Initialise an untrained Parler-TTS model and save it to disk.

The composite model is assembled from a pretrained text encoder
(``google/flan-t5-base``), a pretrained audio codec
(``ylacombe/dac_44khZ_8kbps``) and a freshly, randomly initialised 24-layer
decoder. The decoder is first written to ``TMP_DIR/decoder`` so that it can be
loaded by path, and the final model is written to
``TMP_DIR/stable-speech-untrained-300M/``. The "300M" label comes from the
output directory name and is not re-derived here.
"""
import os

from transformers import AutoConfig

from parler_tts import (
    ParlerTTSForCausalLM,
    ParlerTTSForConditionalGeneration,
    ParlerTTSDecoderConfig,
)

TMP_DIR = "./tmp/artefacts/"

text_model = "google/flan-t5-base"
# NOTE: despite the variable name, the checkpoint name indicates a DAC codec
# (44 kHz, 8 kbps) rather than EnCodec.
encodec_version = "ylacombe/dac_44khZ_8kbps"
num_codebooks = 9

t5 = AutoConfig.from_pretrained(text_model)
encodec = AutoConfig.from_pretrained(encodec_version)

encodec_vocab_size = encodec.codebook_size

# Special audio-token ids sit just past the codec's codebook entries. They are
# needed both by the decoder config and by the generation config below, so
# they are defined once to keep the two in sync.
audio_pad_token_id = encodec_vocab_size  # also used as the EOS id
audio_bos_token_id = encodec_vocab_size + 1

decoder_config = ParlerTTSDecoderConfig(
    # Leave room for the special tokens. Using +64 rather than +1 keeps the
    # embedding size a multiple of 64 (assuming codebook_size already is one).
    vocab_size=encodec_vocab_size + 64,
    max_position_embeddings=4096,  # 30 s = 2580 frames, so this leaves headroom
    num_hidden_layers=24,
    ffn_dim=4096,
    num_attention_heads=16,
    layerdrop=0.0,
    use_cache=True,
    activation_function="gelu",
    hidden_size=1024,
    dropout=0.1,
    attention_dropout=0.0,
    activation_dropout=0.0,
    pad_token_id=audio_pad_token_id,
    eos_token_id=audio_pad_token_id,
    bos_token_id=audio_bos_token_id,
    num_codebooks=num_codebooks,
)

# Persist the randomly initialised decoder so the composite model can load it
# by path in the next step.
decoder = ParlerTTSForCausalLM(decoder_config)
decoder_dir = os.path.join(TMP_DIR, "decoder")
decoder.save_pretrained(decoder_dir)

model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
    text_encoder_pretrained_model_name_or_path=text_model,
    audio_encoder_pretrained_model_name_or_path=encodec_version,
    decoder_pretrained_model_name_or_path=decoder_dir,
    # presumably sets the composite config's text vocabulary size — confirm
    vocab_size=t5.vocab_size,
)

# Generation-time special-token ids must match the decoder config above.
model.generation_config.decoder_start_token_id = audio_bos_token_id
model.generation_config.pad_token_id = audio_pad_token_id
model.generation_config.eos_token_id = audio_pad_token_id

# Other default generation parameters.
# Cap generation at 30 s of audio, expressed in codec frames.
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = True
model.generation_config.guidance_scale = 1  # an earlier note suggested 3.0

model.save_pretrained(os.path.join(TMP_DIR, "stable-speech-untrained-300M/"))