Unverified commit 435d37ce, authored by Aritra Roy Gosthipaty, committed by GitHub
Browse files

[Tests] reduce the model size in the audioldm fast test (#7833)

chore: initial size reduction of models
parent 5915c298
...@@ -66,16 +66,17 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -66,16 +66,17 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet2DConditionModel( unet = UNet2DConditionModel(
block_out_channels=(32, 64), block_out_channels=(8, 16),
layers_per_block=2, layers_per_block=1,
norm_num_groups=8,
sample_size=32, sample_size=32,
in_channels=4, in_channels=4,
out_channels=4, out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=(32, 64), cross_attention_dim=(8, 16),
class_embed_type="simple_projection", class_embed_type="simple_projection",
projection_class_embeddings_input_dim=32, projection_class_embeddings_input_dim=8,
class_embeddings_concat=True, class_embeddings_concat=True,
) )
scheduler = DDIMScheduler( scheduler = DDIMScheduler(
...@@ -87,9 +88,10 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -87,9 +88,10 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
) )
torch.manual_seed(0) torch.manual_seed(0)
vae = AutoencoderKL( vae = AutoencoderKL(
block_out_channels=[32, 64], block_out_channels=[8, 16],
in_channels=1, in_channels=1,
out_channels=1, out_channels=1,
norm_num_groups=8,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4, latent_channels=4,
...@@ -98,14 +100,14 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -98,14 +100,14 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
text_encoder_config = ClapTextConfig( text_encoder_config = ClapTextConfig(
bos_token_id=0, bos_token_id=0,
eos_token_id=2, eos_token_id=2,
hidden_size=32, hidden_size=8,
intermediate_size=37, intermediate_size=37,
layer_norm_eps=1e-05, layer_norm_eps=1e-05,
num_attention_heads=4, num_attention_heads=1,
num_hidden_layers=5, num_hidden_layers=1,
pad_token_id=1, pad_token_id=1,
vocab_size=1000, vocab_size=1000,
projection_dim=32, projection_dim=8,
) )
text_encoder = ClapTextModelWithProjection(text_encoder_config) text_encoder = ClapTextModelWithProjection(text_encoder_config)
tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77) tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment