Unverified Commit 2ed7e05f authored by Sebastian's avatar Sebastian Committed by GitHub
Browse files

Improve performance of fast test by reducing down blocks (#5290)

* Reduce number of down block channels

* Remove debug code

* Set new excepted image slice values for sdxl euler test
parent cc2c4ae7
...@@ -51,7 +51,7 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest ...@@ -51,7 +51,7 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet2DConditionModel( unet = UNet2DConditionModel(
block_out_channels=(32, 64), block_out_channels=(2, 4),
layers_per_block=2, layers_per_block=2,
sample_size=32, sample_size=32,
in_channels=4, in_channels=4,
...@@ -66,6 +66,7 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest ...@@ -66,6 +66,7 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
transformer_layers_per_block=(1, 2), transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32 projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64, cross_attention_dim=64,
norm_num_groups=1,
) )
scheduler = EulerDiscreteScheduler( scheduler = EulerDiscreteScheduler(
beta_start=0.00085, beta_start=0.00085,
...@@ -144,7 +145,7 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest ...@@ -144,7 +145,7 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3) assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5873, 0.6128, 0.4797, 0.5122, 0.5674, 0.4639, 0.5227, 0.5149, 0.4747]) expected_slice = np.array([0.5552, 0.5569, 0.4725, 0.4348, 0.4994, 0.4632, 0.5142, 0.5012, 0.47])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment