Unverified commit c5594795, authored by YiYi Xu and committed by GitHub

Postprocessing refactor all others (#3337)



* add text2img

* fix-copies

* add

* add all other pipelines

* add

* add

* add

* add

* add

* make style

* style + fix copies

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent a757b2db
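Every pipeline file in this diff applies the same two changes: decode_latents gains a deprecation warning, and the post-processing tail of __call__ is rewritten around VaeImageProcessor. Condensed from the hunks below, the before/after of that tail:

    # Before: decode to a numpy array in [0, 1], then optionally convert to PIL.
    image = self.decode_latents(latents)
    if output_type == "pil":
        image = self.numpy_to_pil(image)

    # After: keep the VAE decode in the pipeline, support a "latent"
    # output type, and delegate all conversions to VaeImageProcessor.
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    else:
        image = latents
    image = self.image_processor.postprocess(image, output_type=output_type)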
@@ -13,12 +13,14 @@
 # limitations under the License.
 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput

+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
@@ -136,6 +138,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -474,6 +477,11 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -918,17 +926,17 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

-        # 14. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents
+
+        image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()

-        # 15. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
-
         if not return_dict:
             return (image,)
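One practical effect of this hunk: output_type="latent" now returns the raw latents with no VAE decode, and the decode/postprocess steps can be rerun manually afterwards. A usage sketch (the checkpoint id is illustrative, not prescribed by this diff):

    import torch
    from diffusers import StableUnCLIPPipeline

    # Illustrative checkpoint id; any StableUnCLIP checkpoint works the same way.
    pipe = StableUnCLIPPipeline.from_pretrained(
        "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16
    ).to("cuda")

    # Skips VAE decoding and postprocessing entirely.
    latents = pipe("a photo of an astronaut", output_type="latent").images

    # Reproduce the pipeline's own post-processing path from the hunk above:
    with torch.no_grad():
        image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    pil_images = pipe.image_processor.postprocess(image, output_type="pil")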
@@ -13,6 +13,7 @@
 # limitations under the License.
 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Union

 import PIL
@@ -21,6 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from diffusers.utils.import_utils import is_accelerate_available

+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
@@ -138,6 +140,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -429,6 +432,11 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -814,16 +822,17 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
                callback(i, t, latents)

-        # 9. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents
+
+        image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()

-        # 10. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
-
         if not return_dict:
             return (image,)
@@ -363,6 +363,11 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
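decode_latents keeps working for now but emits a FutureWarning on every call. A minimal sketch of surfacing the deprecation, assuming a loaded pipe and latents tensor already exist:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        image = pipe.decode_latents(latents)  # deprecated path, still functional
    assert any(issubclass(w.category, FutureWarning) for w in caught)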
@@ -13,6 +13,7 @@
 # limitations under the License.
 import inspect
+import warnings
 from typing import Callable, List, Optional, Tuple, Union

 import numpy as np
@@ -26,6 +27,7 @@ from transformers import (
     CLIPVisionModelWithProjection,
 )

+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, logging, randn_tensor
@@ -88,6 +90,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

         if self.text_unet is not None and (
             "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention
@@ -329,6 +332,11 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -572,12 +580,12 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

-        # 9. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents

-        # 10. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)

         if not return_dict:
             return (image,)
@@ -13,6 +13,7 @@
 # limitations under the License.
 import inspect
+import warnings
 from typing import Callable, List, Optional, Union

 import numpy as np
@@ -21,6 +22,7 @@ import torch
 import torch.utils.checkpoint
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, logging, randn_tensor
@@ -71,6 +73,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
@@ -189,6 +192,11 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -414,12 +422,12 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

-        # 8. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents

-        # 9. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)

         if not return_dict:
             return (image,)
@@ -13,12 +13,14 @@
 # limitations under the License.
 import inspect
+import warnings
 from typing import Callable, List, Optional, Union

 import torch
 import torch.utils.checkpoint
 from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer

+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, logging, randn_tensor
@@ -76,6 +78,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

         if self.text_unet is not None:
             self._swap_unet_attention_blocks()
@@ -246,6 +249,11 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -488,12 +496,12 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

-        # 9. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents

-        # 10. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)

         if not return_dict:
             return (image,)
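For downstream code that called the now-deprecated decode_latents directly, an equivalent can be built from the public pieces introduced above. A sketch, where pipe stands for any of the pipelines touched in this diff:

    import numpy as np
    import torch

    def decode_latents_via_processor(pipe, latents: torch.Tensor) -> np.ndarray:
        # Same computation decode_latents performed: unscale, VAE-decode,
        # then denormalize from [-1, 1] to [0, 1] as NHWC float32 numpy.
        with torch.no_grad():
            image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
        return pipe.image_processor.postprocess(image, output_type="np")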
@@ -28,17 +28,18 @@ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
 from diffusers.utils import slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu

-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

 torch.backends.cuda.matmul.allow_tf32 = False

-class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = AltDiffusionPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

     def get_dummy_components(self):
         torch.manual_seed(0)
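PipelineLatentTesterMixin is what consumes the new image_params class attribute: roughly, it checks that a refactored pipeline produces consistent results across output types. A simplified illustration of that idea (not the mixin's actual code; get_inputs is assumed to return fresh fixed-seed inputs on each call):

    import numpy as np

    def check_output_types_agree(pipe, get_inputs):
        output_pt = pipe(**get_inputs(), output_type="pt").images
        output_np = pipe(**get_inputs(), output_type="np").images
        output_pil = pipe(**get_inputs(), output_type="pil").images

        # "pt" is NCHW tensors, "np" is NHWC arrays; same values either way.
        max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
        assert max_diff < 1e-4

        # "pil" round-trips through uint8, so allow quantization error.
        max_diff = np.abs(np.array(output_pil[0], dtype=np.float32) / 255.0 - output_np[0]).max()
        assert max_diff < 2 / 255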
@@ -123,6 +123,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
         tokenizer.model_max_length = 77

         init_image = self.dummy_image.to(device)
+        init_image = init_image / 2 + 0.5

         # make sure here that pndm scheduler skips prk
         alt_pipe = AltDiffusionImg2ImgPipeline(
@@ -134,7 +135,7 @@
             safety_checker=None,
             feature_extractor=self.dummy_extractor,
         )
-        alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
+        alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=True)
        alt_pipe = alt_pipe.to(device)
        alt_pipe.set_progress_bar_config(disable=None)
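The do_normalize flip and the new init_image / 2 + 0.5 line go together: VaeImageProcessor with do_normalize=True maps [0, 1] inputs into the [-1, 1] range the VAE expects, so the dummy image (already in [-1, 1]) is first shifted back into [0, 1]. The arithmetic, for reference:

    import torch

    image_01 = torch.rand(1, 3, 32, 32)   # [0, 1], how PIL/numpy inputs arrive
    image_pm1 = 2.0 * image_01 - 1.0      # what do_normalize=True produces
    assert torch.allclose(image_01, image_pm1 / 2 + 0.5)  # the test's inverse shift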
@@ -38,6 +38,7 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = PaintByExamplePipeline
     params = IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
     batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
+    image_params = frozenset([])  # TO_DO: update the image_params once refactored VaeImageProcessor.preprocess

     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -26,13 +26,13 @@ from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_d
 from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

 torch.backends.cuda.matmul.allow_tf32 = False

-class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = CycleDiffusionPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {
         "negative_prompt",
@@ -42,6 +42,7 @@
     }
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"})
+    image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess

     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -42,17 +42,18 @@ from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu

 from ...models.test_models_unet_2d_condition import create_lora_layers
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

 torch.backends.cuda.matmul.allow_tf32 = False

-class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -35,13 +35,14 @@ from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import require_torch_gpu

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

-class StableDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionControlNetPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess

     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -33,16 +33,21 @@ from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow
 from diffusers.utils.testing_utils import require_torch_gpu

 from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

 torch.backends.cuda.matmul.allow_tf32 = False

-class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionImageVariationPipelineFastTests(
+    PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
     pipeline_class = StableDiffusionImageVariationPipeline
     params = IMAGE_VARIATION_PARAMS
     batch_params = IMAGE_VARIATION_BATCH_PARAMS
+    image_params = frozenset(
+        []
+    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess

     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -36,16 +36,19 @@ from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow
 from diffusers.utils.testing_utils import require_torch_gpu

 from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

 torch.backends.cuda.matmul.allow_tf32 = False

-class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionInpaintPipeline
     params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
+    image_params = frozenset(
+        []
+    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess

     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -35,16 +35,21 @@ from diffusers.utils import floats_tensor, load_image, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu

 from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

 torch.backends.cuda.matmul.allow_tf32 = False

-class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionInstructPix2PixPipelineFastTests(
+    PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
     pipeline_class = StableDiffusionInstructPix2PixPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"}
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
+    image_params = frozenset(
+        []
+    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess

     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -31,18 +31,19 @@ from diffusers import (
 from diffusers.utils import slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

 torch.backends.cuda.matmul.allow_tf32 = False

 @skip_mps
-class StableDiffusionModelEditingPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionModelEditingPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionModelEditingPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -32,18 +32,19 @@ from diffusers import (
 from diffusers.utils import slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

 torch.backends.cuda.matmul.allow_tf32 = False

 @skip_mps
-class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionPanoramaPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -36,17 +36,20 @@ from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps

 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

 torch.backends.cuda.matmul.allow_tf32 = False

 @skip_mps
-class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionPix2PixZeroPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+    image_params = frozenset(
+        []
+    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess

     @classmethod
     def setUpClass(cls):
@@ -29,17 +29,18 @@ from diffusers import (
 from diffusers.utils import slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu

-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

 torch.backends.cuda.matmul.allow_tf32 = False

-class StableDiffusionSAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionSAGPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     test_cpu_offload = False

     def get_dummy_components(self):
@@ -35,17 +35,18 @@ from diffusers import (
 from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu

-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

 torch.backends.cuda.matmul.allow_tf32 = False

-class StableDiffusion2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusion2PipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

     def get_dummy_components(self):
         torch.manual_seed(0)