Unverified commit f5e5f348 authored by Sayak Paul, committed by GitHub

[modular] add tests for qwen modular (#12585)

* add tests for qwenimage modular.

* qwenimage edit.

* qwenimage edit plus.

* empty

* align with the latest structure

* up

* up

* reason

* up

* fix multiple issues.

* up

* up

* fix

* up

* make it similar to the original pipeline.
parent 093cd3f0
@@ -132,6 +132,7 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
+            InputParam("latents"),
             InputParam(name="height"),
             InputParam(name="width"),
             InputParam(name="num_images_per_prompt", default=1),
@@ -196,11 +197,11 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
                     f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
                     f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                 )
-        block_state.latents = randn_tensor(
-            shape, generator=block_state.generator, device=device, dtype=block_state.dtype
-        )
+        if block_state.latents is None:
+            block_state.latents = randn_tensor(
+                shape, generator=block_state.generator, device=device, dtype=block_state.dtype
+            )
         block_state.latents = components.pachifier.pack_latents(block_state.latents)
         self.set_block_state(state, block_state)

         return components, state
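With the new guard (and the matching `InputParam("latents")` above), `latents` becomes a true optional input: a caller-supplied tensor flows through to packing unchanged, and noise is only drawn when nothing was provided. A minimal sketch of the resulting control flow, with plain `torch.randn` and a namespace object standing in for the block's `randn_tensor` helper and state:

```python
import torch
from types import SimpleNamespace

def prepare_latents(block_state, shape, device, dtype):
    # Draw fresh noise only when the caller did not supply latents.
    if block_state.latents is None:
        block_state.latents = torch.randn(shape, device=device, dtype=dtype)
    return block_state

# Without user latents: random initialization happens.
state = prepare_latents(SimpleNamespace(latents=None), (1, 16, 4, 4), "cpu", torch.float32)
assert state.latents.shape == (1, 16, 4, 4)

# With user latents: they pass through untouched.
fixed = torch.zeros(1, 16, 4, 4)
state = prepare_latents(SimpleNamespace(latents=fixed), (1, 16, 4, 4), "cpu", torch.float32)
assert state.latents is fixed
```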
@@ -549,8 +550,7 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
                     block_state.width // components.vae_scale_factor // 2,
                 )
             ]
-            * block_state.batch_size
-        ]
+        ] * block_state.batch_size
         block_state.txt_seq_lens = (
             block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
         )
...
@@ -74,8 +74,9 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
         block_state = self.get_block_state(state)

         # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
+        vae_scale_factor = components.vae_scale_factor
         block_state.latents = components.pachifier.unpack_latents(
-            block_state.latents, block_state.height, block_state.width
+            block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
         )
         block_state.latents = block_state.latents.to(components.vae.dtype)
...
@@ -503,6 +503,8 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks):
         block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
         block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]

+        block_state.negative_prompt_embeds = None
+        block_state.negative_prompt_embeds_mask = None
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or ""
             block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
@@ -627,6 +629,8 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
             device=device,
         )

+        block_state.negative_prompt_embeds = None
+        block_state.negative_prompt_embeds_mask = None
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or " "
             block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
@@ -679,6 +683,8 @@ class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
             device=device,
         )

+        block_state.negative_prompt_embeds = None
+        block_state.negative_prompt_embeds_mask = None
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or " "
             block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (
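All three encoder steps now initialize the negative embedding fields before the conditional, so the attributes are always defined on the block state even when `requires_unconditional_embeds` is false and no negative pass runs. A toy illustration of the pattern (hypothetical `encode` helper, not the library code):

```python
from types import SimpleNamespace

def encode(block_state, requires_unconditional_embeds):
    # Defaults guarantee the fields exist even when CFG is disabled.
    block_state.negative_prompt_embeds = None
    block_state.negative_prompt_embeds_mask = None
    if requires_unconditional_embeds:
        block_state.negative_prompt_embeds = "computed-negative-embeds"
        block_state.negative_prompt_embeds_mask = "computed-negative-mask"
    return block_state

state = encode(SimpleNamespace(), requires_unconditional_embeds=False)
assert state.negative_prompt_embeds is None  # defined, just empty
```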
...
@@ -26,10 +26,7 @@ class QwenImagePachifier(ConfigMixin):
     config_name = "config.json"

     @register_to_config
-    def __init__(
-        self,
-        patch_size: int = 2,
-    ):
+    def __init__(self, patch_size: int = 2):
         super().__init__()

     def pack_latents(self, latents):
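For context, the pachifier converts between spatial latents and sequences of flattened patches; the decoder hunk above additionally threads `vae_scale_factor` through so pixel-space `height`/`width` can be mapped back to the latent grid before unpacking. A hedged sketch of what a `patch_size=2` pack/unpack pair does (an illustration of the idea, not the exact `QwenImagePachifier` code; `height`/`width` here are latent-space sizes):

```python
import torch

def pack_latents(latents, patch_size=2):
    # (B, C, H, W) -> (B, H/p * W/p, C * p * p): each p x p patch becomes one token.
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // patch_size) * (w // patch_size), c * patch_size * patch_size)

def unpack_latents(packed, height, width, patch_size=2):
    # Inverse of pack_latents for the same latent height/width.
    b, _, d = packed.shape
    c = d // (patch_size * patch_size)
    h, w = height // patch_size, width // patch_size
    packed = packed.view(b, h, w, c, patch_size, patch_size)
    packed = packed.permute(0, 3, 1, 4, 2, 5)
    return packed.reshape(b, c, height, width)

x = torch.randn(1, 16, 8, 8)
assert torch.equal(unpack_latents(pack_latents(x), 8, 8), x)  # exact round trip
```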
...
@@ -55,6 +55,9 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
         }
         return inputs

+    def test_float16_inference(self):
+        super().test_float16_inference(9e-2)
+

 class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxModularPipeline
@@ -118,6 +121,9 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
         assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3

+    def test_float16_inference(self):
+        super().test_float16_inference(8e-2)
+

 class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxKontextModularPipeline
@@ -170,3 +176,6 @@ class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
             image_slices.append(image[0, -3:, -3:, -1].flatten())

         assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+    def test_float16_inference(self):
+        super().test_float16_inference(9e-2)
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import PIL
import pytest
from diffusers.modular_pipelines import (
    QwenImageAutoBlocks,
    QwenImageEditAutoBlocks,
    QwenImageEditModularPipeline,
    QwenImageEditPlusAutoBlocks,
    QwenImageEditPlusModularPipeline,
    QwenImageModularPipeline,
)

from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin

class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = QwenImageModularPipeline
    pipeline_blocks_class = QwenImageAutoBlocks
    repo = "hf-internal-testing/tiny-qwenimage-modular"

    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

    def get_dummy_inputs(self):
        generator = self.get_generator()
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=5e-4)

class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = QwenImageEditModularPipeline
    pipeline_blocks_class = QwenImageEditAutoBlocks
    repo = "hf-internal-testing/tiny-qwenimage-edit-modular"

    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

    def get_dummy_inputs(self):
        generator = self.get_generator()
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
        return inputs

    def test_guider_cfg(self):
        super().test_guider_cfg(7e-5)

class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = QwenImageEditPlusModularPipeline
    pipeline_blocks_class = QwenImageEditPlusAutoBlocks
    repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"

    # No `mask_image` yet.
    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
    batch_params = frozenset(["prompt", "negative_prompt", "image"])

    def get_dummy_inputs(self):
        generator = self.get_generator()
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
        return inputs

    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
    def test_num_images_per_prompt(self):
        super().test_num_images_per_prompt()

    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
    def test_inference_batch_consistent(self):
        super().test_inference_batch_consistent()

    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical()

    def test_guider_cfg(self):
        super().test_guider_cfg(1e-3)
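The `strict=True` on these `xfail` markers makes them self-expiring: once multi-image batching is fixed and a test unexpectedly passes, pytest reports the XPASS as a failure, forcing removal of the marker. A minimal standalone illustration:

```python
import pytest

@pytest.mark.xfail(condition=True, reason="known limitation", strict=True)
def test_known_limitation():
    # Fails today, so pytest records an xfail; if this ever starts passing,
    # strict=True turns the unexpected pass into a suite failure.
    raise AssertionError("still broken")
```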
@@ -25,7 +25,7 @@ from diffusers.loaders import ModularIPAdapterMixin
 from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
 from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modular_pipelines_common import ModularPipelineTesterMixin
+from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin

 enable_full_determinism()
@@ -37,13 +37,11 @@ class SDXLModularTesterMixin:
     """

     def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
-        sd_pipe = self.get_pipeline()
-        sd_pipe = sd_pipe.to(torch_device)
-        sd_pipe.set_progress_bar_config(disable=None)
+        sd_pipe = self.get_pipeline().to(torch_device)
         inputs = self.get_dummy_inputs()
         image = sd_pipe(**inputs, output="images")
-        image_slice = image[0, -3:, -3:, -1]
+        image_slice = image[0, -3:, -3:, -1].cpu()

         assert image.shape == expected_image_shape
         max_diff = torch.abs(image_slice.flatten() - expected_slice).max()
@@ -110,7 +108,7 @@ class SDXLModularIPAdapterTesterMixin:
         pipe = blocks.init_pipeline(self.repo)
         pipe.load_components(torch_dtype=torch.float32)
         pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)

         cross_attention_dim = pipe.unet.config.get("cross_attention_dim")

         # forward pass without ip adapter
@@ -219,9 +217,7 @@ class SDXLModularControlNetTesterMixin:
         # compare against static slices and that can be shaky (with a VVVV low probability).
         expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff

-        pipe = self.get_pipeline()
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         # forward pass without controlnet
         inputs = self.get_dummy_inputs()
@@ -251,9 +247,7 @@ class SDXLModularControlNetTesterMixin:
         assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"

     def test_controlnet_cfg(self):
-        pipe = self.get_pipeline()
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         # forward pass with CFG not applied
         guider = ClassifierFreeGuidance(guidance_scale=1.0)
@@ -273,35 +267,11 @@ class SDXLModularControlNetTesterMixin:
         assert max_diff > 1e-2, "Output with CFG must be different from normal inference"


-class SDXLModularGuiderTesterMixin:
-    def test_guider_cfg(self):
-        pipe = self.get_pipeline()
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        # forward pass with CFG not applied
-        guider = ClassifierFreeGuidance(guidance_scale=1.0)
-        pipe.update_components(guider=guider)
-
-        inputs = self.get_dummy_inputs()
-        out_no_cfg = pipe(**inputs, output="images")
-
-        # forward pass with CFG applied
-        guider = ClassifierFreeGuidance(guidance_scale=7.5)
-        pipe.update_components(guider=guider)
-
-        inputs = self.get_dummy_inputs()
-        out_cfg = pipe(**inputs, output="images")
-
-        assert out_cfg.shape == out_no_cfg.shape
-        max_diff = np.abs(out_cfg - out_no_cfg).max()
-        assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
-
-
 class TestSDXLModularPipelineFast(
     SDXLModularTesterMixin,
     SDXLModularIPAdapterTesterMixin,
     SDXLModularControlNetTesterMixin,
-    SDXLModularGuiderTesterMixin,
+    ModularGuiderTesterMixin,
     ModularPipelineTesterMixin,
 ):
     """Test cases for Stable Diffusion XL modular pipeline fast tests."""
@@ -335,18 +305,7 @@ class TestSDXLModularPipelineFast(
         self._test_stable_diffusion_xl_euler(
             expected_image_shape=self.expected_image_output_shape,
             expected_slice=torch.tensor(
-                [
-                    0.5966781,
-                    0.62939394,
-                    0.48465094,
-                    0.51573336,
-                    0.57593524,
-                    0.47035995,
-                    0.53410417,
-                    0.51436996,
-                    0.47313565,
-                ],
-                device=torch_device,
+                [0.3886, 0.4685, 0.4953, 0.4217, 0.4317, 0.3945, 0.4847, 0.4704, 0.4731],
             ),
             expected_max_diff=1e-2,
         )
@@ -359,7 +318,7 @@ class TestSDXLImg2ImgModularPipelineFast(
     SDXLModularTesterMixin,
     SDXLModularIPAdapterTesterMixin,
     SDXLModularControlNetTesterMixin,
-    SDXLModularGuiderTesterMixin,
+    ModularGuiderTesterMixin,
     ModularPipelineTesterMixin,
 ):
     """Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
@@ -400,20 +359,7 @@ class TestSDXLImg2ImgModularPipelineFast(
     def test_stable_diffusion_xl_euler(self):
         self._test_stable_diffusion_xl_euler(
             expected_image_shape=self.expected_image_output_shape,
-            expected_slice=torch.tensor(
-                [
-                    0.56943184,
-                    0.4702148,
-                    0.48048905,
-                    0.6235963,
-                    0.551138,
-                    0.49629188,
-                    0.60031277,
-                    0.5688907,
-                    0.43996853,
-                ],
-                device=torch_device,
-            ),
+            expected_slice=torch.tensor([0.5246, 0.4466, 0.444, 0.3246, 0.4443, 0.5108, 0.5225, 0.559, 0.5147]),
             expected_max_diff=1e-2,
         )
@@ -425,7 +371,7 @@ class SDXLInpaintingModularPipelineFastTests(
     SDXLModularTesterMixin,
     SDXLModularIPAdapterTesterMixin,
     SDXLModularControlNetTesterMixin,
-    SDXLModularGuiderTesterMixin,
+    ModularGuiderTesterMixin,
     ModularPipelineTesterMixin,
 ):
     """Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
...
@@ -2,22 +2,17 @@ import gc
 import tempfile
 from typing import Callable, Union

+import pytest
 import torch

 import diffusers
 from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
+from diffusers.guiders import ClassifierFreeGuidance
 from diffusers.utils import logging

-from ..testing_utils import (
-    backend_empty_cache,
-    numpy_cosine_similarity_distance,
-    require_accelerator,
-    require_torch,
-    torch_device,
-)
+from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device


-@require_torch
 class ModularPipelineTesterMixin:
     """
     It provides a set of common tests for each modular pipeline,
@@ -32,20 +27,9 @@ class ModularPipelineTesterMixin:
     # Canonical parameters that are passed to `__call__` regardless
     # of the type of pipeline. They are always optional and have common
     # sense default values.
-    optional_params = frozenset(
-        [
-            "num_inference_steps",
-            "num_images_per_prompt",
-            "latents",
-            "output_type",
-        ]
-    )
+    optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])

     # this is modular specific: generator needs to be a intermediate input because it's mutable
-    intermediate_params = frozenset(
-        [
-            "generator",
-        ]
-    )
+    intermediate_params = frozenset(["generator"])

     def get_generator(self, seed=0):
         generator = torch.Generator("cpu").manual_seed(seed)
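The comment above is the key constraint: a `torch.Generator` mutates as it is consumed, so it cannot be modeled as a plain immutable input. A quick demonstration of why it must be tracked as intermediate state:

```python
import torch

g = torch.Generator("cpu").manual_seed(0)
a = torch.randn(2, generator=g)
b = torch.randn(2, generator=g)
assert not torch.equal(a, b)  # the same generator advances between draws

g.manual_seed(0)
assert torch.equal(torch.randn(2, generator=g), a)  # reseeding reproduces the first draw
```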
@@ -121,6 +105,7 @@ class ModularPipelineTesterMixin:
     def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
         pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
         pipeline.load_components(torch_dtype=torch_dtype)
+        pipeline.set_progress_bar_config(disable=None)
         return pipeline

     def test_pipeline_call_signature(self):
@@ -138,9 +123,7 @@ class ModularPipelineTesterMixin:
         _check_for_parameters(self.optional_params, optional_parameters, "optional")

     def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
-        pipe = self.get_pipeline()
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         inputs = self.get_dummy_inputs()
         inputs["generator"] = self.get_generator(0)
@@ -179,9 +162,8 @@ class ModularPipelineTesterMixin:
         batch_size=2,
         expected_max_diff=1e-4,
     ):
-        pipe = self.get_pipeline()
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         inputs = self.get_dummy_inputs()
         # Reset generator in case it is has been used in self.get_dummy_inputs
@@ -219,11 +201,9 @@ class ModularPipelineTesterMixin:
     def test_float16_inference(self, expected_max_diff=5e-2):
         pipe = self.get_pipeline()
         pipe.to(torch_device, torch.float32)
-        pipe.set_progress_bar_config(disable=None)

         pipe_fp16 = self.get_pipeline()
         pipe_fp16.to(torch_device, torch.float16)
-        pipe_fp16.set_progress_bar_config(disable=None)

         inputs = self.get_dummy_inputs()
         # Reset generator in case it is used inside dummy inputs
@@ -237,19 +217,16 @@ class ModularPipelineTesterMixin:
         fp16_inputs["generator"] = self.get_generator(0)
         output_fp16 = pipe_fp16(**fp16_inputs, output="images")

-        if isinstance(output, torch.Tensor):
-            output = output.cpu()
-            output_fp16 = output_fp16.cpu()
+        output = output.cpu()
+        output_fp16 = output_fp16.cpu()

         max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
         assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
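Note the fp16/fp32 check compares outputs with a cosine-style distance rather than an elementwise max, which is robust to the uniform noise reduced precision introduces. Roughly (an assumption about the helper's behavior, not its verbatim source):

```python
import numpy as np

def cosine_similarity_distance(a, b):
    # 0.0 for identical directions, approaching 2.0 for opposite ones.
    a = np.asarray(a, dtype=np.float64).flatten()
    b = np.asarray(b, dtype=np.float64).flatten()
    return 1.0 - float(a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))

assert cosine_similarity_distance([1.0, 0.0], [2.0, 0.0]) < 1e-12  # scale-invariant
```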
     @require_accelerator
     def test_to_device(self):
-        pipe = self.get_pipeline()
-        pipe.set_progress_bar_config(disable=None)
-        pipe.to("cpu")
+        pipe = self.get_pipeline().to("cpu")

         model_devices = [
             component.device.type for component in pipe.components.values() if hasattr(component, "device")
         ]
@@ -264,30 +241,23 @@ class ModularPipelineTesterMixin:
         )

     def test_inference_is_not_nan_cpu(self):
-        pipe = self.get_pipeline()
-        pipe.set_progress_bar_config(disable=None)
-        pipe.to("cpu")
+        pipe = self.get_pipeline().to("cpu")

         output = pipe(**self.get_dummy_inputs(), output="images")
         assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"
     @require_accelerator
     def test_inference_is_not_nan(self):
-        pipe = self.get_pipeline()
-        pipe.set_progress_bar_config(disable=None)
-        pipe.to(torch_device)
+        pipe = self.get_pipeline().to(torch_device)

         output = pipe(**self.get_dummy_inputs(), output="images")
         assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN"
     def test_num_images_per_prompt(self):
-        pipe = self.get_pipeline()
+        pipe = self.get_pipeline().to(torch_device)

         if "num_images_per_prompt" not in pipe.blocks.input_names:
-            return
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+            pytest.skip("Skipping test as `num_images_per_prompt` is not present in input names.")

         batch_sizes = [1, 2]
         num_images_per_prompts = [1, 2]
@@ -342,3 +312,25 @@ class ModularPipelineTesterMixin:
             image_slices.append(image[0, -3:, -3:, -1].flatten())

         assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+
+class ModularGuiderTesterMixin:
+    def test_guider_cfg(self, expected_max_diff=1e-2):
+        pipe = self.get_pipeline().to(torch_device)
+
+        # forward pass with CFG not applied
+        guider = ClassifierFreeGuidance(guidance_scale=1.0)
+        pipe.update_components(guider=guider)
+
+        inputs = self.get_dummy_inputs()
+        out_no_cfg = pipe(**inputs, output="images")
+
+        # forward pass with CFG applied
+        guider = ClassifierFreeGuidance(guidance_scale=7.5)
+        pipe.update_components(guider=guider)
+
+        inputs = self.get_dummy_inputs()
+        out_cfg = pipe(**inputs, output="images")
+
+        assert out_cfg.shape == out_no_cfg.shape
+        max_diff = torch.abs(out_cfg - out_no_cfg).max()
+        assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
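One subtlety: despite its name, `expected_max_diff` here acts as a minimum required difference (the final assert checks `max_diff > expected_max_diff`), so the per-pipeline overrides above, such as `test_guider_cfg(7e-5)`, tune how strongly CFG must move the output. A hypothetical opt-in for a new pipeline's tests (placeholder names):

```python
class TestMyModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    # Placeholders; a real test class points at its pipeline, blocks, and tiny repo.
    pipeline_class = None
    pipeline_blocks_class = None
    repo = "hf-internal-testing/tiny-my-pipeline-modular"

    def test_guider_cfg(self):
        # Require CFG to change the output by at least this much.
        super().test_guider_cfg(expected_max_diff=1e-3)
```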