Commit cb44e483 authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

fix vocab_size in dummy init

parent 226fe07f
from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
from transformers import T5Config, EncodecConfig from transformers import T5Config, EncodecConfig
from transformers import AutoConfig
decoder_config = StableSpeechDecoderConfig( decoder_config = StableSpeechDecoderConfig(
max_position_embeddings=1024, max_position_embeddings=2048,
num_hidden_layers=2, num_hidden_layers=2,
ffn_dim=256, ffn_dim=256,
num_attention_heads=4, num_attention_heads=4,
...@@ -24,11 +24,13 @@ decoder = StableSpeechForCausalLM(decoder_config) ...@@ -24,11 +24,13 @@ decoder = StableSpeechForCausalLM(decoder_config)
decoder.save_pretrained("/home/yoach/dataspeech/artefacts/decoder/") decoder.save_pretrained("/home/yoach/dataspeech/artefacts/decoder/")
t5 = AutoConfig.from_pretrained("t5-base")
model = StableSpeechForConditionalGeneration.from_sub_models_pretrained( model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path="t5-base", text_encoder_pretrained_model_name_or_path="t5-base",
audio_encoder_pretrained_model_name_or_path="facebook/encodec_32khz", audio_encoder_pretrained_model_name_or_path="facebook/encodec_32khz",
decoder_pretrained_model_name_or_path="/home/yoach/dataspeech/artefacts/decoder/", decoder_pretrained_model_name_or_path="/home/yoach/dataspeech/artefacts/decoder/",
vocab_size = t5.vocab_size
) )
# set the appropriate bos/pad token ids # set the appropriate bos/pad token ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment