Unverified Commit c4d66200 authored by __mo_san__, committed by GitHub

make-fast-test-for-StableDiffusionControlNetPipeline-faster (#5292)



* decrease UNet2DConditionModel & ControlNetModel blocks

* decrease UNet2DConditionModel & ControlNetModel blocks

* decrease even more blocks & number of norm groups

* decrease VAE block out channels and number of norm groups

* fix code style

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 2ed7e05f
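The diff below is mechanical: every dummy model in these fast tests trades its 32/64-channel blocks for 4/8-channel ones, and `norm_num_groups` is pinned so that GroupNorm still divides the new widths. As a rough illustration of why this speeds the tests up, here is a sketch comparing parameter counts of the old and new dummy UNet configs. The `make_unet` helper and the `out_channels=4` argument are assumptions for the sake of a self-contained example (that argument sits in a collapsed part of the diff); the other kwargs mirror the test config, and the printed counts will vary across diffusers versions.

```python
# Hypothetical sketch, not part of the PR: builds the dummy UNet with the old
# and new widths from this diff and compares parameter counts.
import torch
from diffusers import UNet2DConditionModel


def make_unet(block_out_channels, norm_num_groups):  # hypothetical helper
    torch.manual_seed(0)
    return UNet2DConditionModel(
        block_out_channels=block_out_channels,
        layers_per_block=2,
        sample_size=32,
        in_channels=4,
        out_channels=4,  # assumed; hidden in the collapsed part of the diff
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
        norm_num_groups=norm_num_groups,
    )


old = make_unet((32, 64), 32)  # pre-PR widths; 32 is the diffusers default for norm_num_groups
new = make_unet((4, 8), 1)     # widths and group count after this commit

print("old params:", sum(p.numel() for p in old.parameters()))
print("new params:", sum(p.numel() for p in new.parameters()))  # far smaller -> faster CPU fast tests
```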
@@ -119,7 +119,7 @@ class ControlNetPipelineFastTests(
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             sample_size=32,
             in_channels=4,
@@ -127,15 +127,17 @@ class ControlNetPipelineFastTests(
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
+            norm_num_groups=1,
         )
         torch.manual_seed(0)
         controlnet = ControlNetModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             in_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             cross_attention_dim=32,
             conditioning_embedding_out_channels=(16, 32),
+            norm_num_groups=1,
         )
         torch.manual_seed(0)
         scheduler = DDIMScheduler(
@@ -147,12 +149,13 @@ class ControlNetPipelineFastTests(
         )
         torch.manual_seed(0)
         vae = AutoencoderKL(
-            block_out_channels=[32, 64],
+            block_out_channels=[4, 8],
             in_channels=3,
             out_channels=3,
             down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
             up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
             latent_channels=4,
+            norm_num_groups=2,
         )
         torch.manual_seed(0)
         text_encoder_config = CLIPTextConfig(
@@ -230,7 +233,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             sample_size=32,
             in_channels=4,
@@ -238,6 +241,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
+            norm_num_groups=1,
         )
         torch.manual_seed(0)

@@ -247,23 +251,25 @@ class StableDiffusionMultiControlNetPipelineFastTests(
                 m.bias.data.fill_(1.0)

         controlnet1 = ControlNetModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             in_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             cross_attention_dim=32,
             conditioning_embedding_out_channels=(16, 32),
+            norm_num_groups=1,
         )
         controlnet1.controlnet_down_blocks.apply(init_weights)

         torch.manual_seed(0)
         controlnet2 = ControlNetModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             in_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             cross_attention_dim=32,
             conditioning_embedding_out_channels=(16, 32),
+            norm_num_groups=1,
         )
         controlnet2.controlnet_down_blocks.apply(init_weights)

@@ -277,12 +283,13 @@ class StableDiffusionMultiControlNetPipelineFastTests(
         )
         torch.manual_seed(0)
         vae = AutoencoderKL(
-            block_out_channels=[32, 64],
+            block_out_channels=[4, 8],
             in_channels=3,
             out_channels=3,
             down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
             up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
             latent_channels=4,
+            norm_num_groups=2,
         )
         torch.manual_seed(0)
         text_encoder_config = CLIPTextConfig(
@@ -415,7 +422,7 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             sample_size=32,
             in_channels=4,
@@ -423,6 +430,7 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
+            norm_num_groups=1,
         )
         torch.manual_seed(0)

@@ -432,12 +440,13 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
                 m.bias.data.fill_(1.0)

         controlnet = ControlNetModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             in_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             cross_attention_dim=32,
             conditioning_embedding_out_channels=(16, 32),
+            norm_num_groups=1,
         )
         controlnet.controlnet_down_blocks.apply(init_weights)

@@ -451,12 +460,13 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
         )
         torch.manual_seed(0)
         vae = AutoencoderKL(
-            block_out_channels=[32, 64],
+            block_out_channels=[4, 8],
             in_channels=3,
             out_channels=3,
             down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
             up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
             latent_channels=4,
+            norm_num_groups=2,
        )
         torch.manual_seed(0)
         text_encoder_config = CLIPTextConfig(
......
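The added `norm_num_groups` arguments are not optional once the blocks shrink: diffusers builds these blocks on `torch.nn.GroupNorm`, whose `num_channels` must be divisible by `num_groups`, and the library default of 32 groups no longer divides widths of 4 and 8. A minimal sketch of that constraint, in plain torch and independent of this PR:

```python
# Minimal sketch of the GroupNorm divisibility rule that forces the new
# norm_num_groups values in this diff.
import torch

torch.nn.GroupNorm(num_groups=32, num_channels=64)  # old 64-channel blocks: 64 % 32 == 0
torch.nn.GroupNorm(num_groups=1, num_channels=4)    # new UNet/ControlNet blocks: 1 always divides
torch.nn.GroupNorm(num_groups=2, num_channels=4)    # new VAE blocks: 4 % 2 == 0 and 8 % 2 == 0

try:
    torch.nn.GroupNorm(num_groups=32, num_channels=4)  # the default would break the 4-channel blocks
except ValueError as err:
    print(err)  # num_channels must be divisible by num_groups
```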