# init_dummy_model.py
from stable_speech import StableSpeechDecoderConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration
from transformers import AutoConfig

text_model = "google-t5/t5-small"
encodec_version = "facebook/encodec_24khz"
num_codebooks = 8

t5 = AutoConfig.from_pretrained(text_model)
encodec = AutoConfig.from_pretrained(encodec_version)

encodec_vocab_size = encodec.codebook_size


# Build a tiny StableSpeech decoder config on top of the EnCodec codebooks.
decoder_config = StableSpeechDecoderConfig(
    vocab_size=encodec_vocab_size + 1,  # codec codes plus one extra id used for pad/eos
    max_position_embeddings=2048,
    num_hidden_layers=4,
    ffn_dim=512,
    num_attention_heads=8,
    layerdrop=0.0,
    use_cache=True,
    activation_function="gelu",
    hidden_size=512,
    dropout=0.0,
    attention_dropout=0.0,
    activation_dropout=0.0,
    pad_token_id=encodec_vocab_size,  # pad/eos reuse the extra id just past the codec codes
    eos_token_id=encodec_vocab_size,
    bos_token_id=encodec_vocab_size + 1,  # bos sits one id further out; reused as decoder_start_token_id below
    num_codebooks=num_codebooks,
)
# TODO: how to make generation stop?

# Instantiate the randomly initialised tiny decoder and save it so the full model can load it below.
decoder = StableSpeechForCausalLM(decoder_config)

decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
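
# Optional sanity-check sketch: confirm the decoder really is tiny before wiring it into
# the full model. Assumes StableSpeechForCausalLM inherits num_parameters() from the
# transformers PreTrainedModel base class (it already exposes save_pretrained above).
print(f"decoder parameters: {decoder.num_parameters():,}")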



# Assemble the full model from the pretrained T5 text encoder, the EnCodec audio encoder and the tiny decoder.
model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
    text_encoder_pretrained_model_name_or_path=text_model,
    audio_encoder_pretrained_model_name_or_path=encodec_version,
    decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
    vocab_size=t5.vocab_size,
)

# set the appropriate bos/pad/eos token ids
model.generation_config.decoder_start_token_id = encodec_vocab_size+1
model.generation_config.pad_token_id = encodec_vocab_size
model.generation_config.eos_token_id = encodec_vocab_size

# set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)  # ~30 seconds of audio at the codec frame rate
model.generation_config.do_sample = False  # alternative: True for sampling
model.generation_config.guidance_scale = 1  # alternative: 3.0 for classifier-free guidance

model.save_pretrained("/raid/yoach/tmp/artefacts/tiny-model/")
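
# Hedged reload sketch: round-trip the saved tiny model and run a short generation as a
# smoke test. Assumes StableSpeechForConditionalGeneration follows a MusicGen-style
# generate() API (tokenised text prompt in, audio codes/values out); the tokenizer is
# simply the T5 tokenizer matching `text_model` above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(text_model)
reloaded = StableSpeechForConditionalGeneration.from_pretrained("/raid/yoach/tmp/artefacts/tiny-model/")

inputs = tokenizer("Hello, this is a test.", return_tensors="pt")
output = reloaded.generate(**inputs, max_new_tokens=32)
print(type(output), getattr(output, "shape", None))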