"docs/vscode:/vscode.git/clone" did not exist on "96f5467b6adfe59a2b63bbcf1356ec736bb06769"
Unverified commit c5594795, authored by YiYi Xu, committed by GitHub

Postprocessing refactor all others (#3337)

* add text2img
* fix-copies
* add
* add all other pipelines
* add
* add
* add
* add
* add
* make style
* style + fix copies
---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent a757b2db
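The same change repeats across every pipeline in this commit: the deprecated decode_latents helper (decode, clamp, convert to NumPy) is replaced by decoding the latents inline and delegating np/pil/pt conversion to VaeImageProcessor.postprocess, which also adds an output_type="latent" escape hatch. A minimal before/after sketch, assuming a pipeline that already has self.vae and the newly registered self.image_processor:

# Before (now deprecated): always decode to NumPy, convert to PIL separately.
image = self.decode_latents(latents)
if output_type == "pil":
    image = self.numpy_to_pil(image)

# After: decode only when an image is actually requested, then let the
# image processor produce np/pil/pt output in one place.
if not output_type == "latent":
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
    image = latents

image = self.image_processor.postprocess(image, output_type=output_type)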
@@ -13,12 +13,14 @@
 # limitations under the License.

 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput

+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
@@ -136,6 +138,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         )

         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -474,6 +477,11 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -918,17 +926,17 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)

-        # 14. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents
+
+        image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()

-        # 15. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
-
         if not return_dict:
             return (image,)

...
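With the processor registered, a caller can ask the pipeline for raw latents and defer decoding; a hedged usage sketch (the checkpoint id and prompt are illustrative, not part of this commit):

import torch
from diffusers import StableUnCLIPPipeline

# Hypothetical checkpoint id, for illustration only.
pipe = StableUnCLIPPipeline.from_pretrained("some-org/stable-unclip", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Skip VAE decoding inside the pipeline ...
latents = pipe("a photo of an astronaut", output_type="latent").images

# ... and decode + post-process later with the same components the pipeline uses.
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")[0]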
@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Union

 import PIL
@@ -21,6 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from diffusers.utils.import_utils import is_accelerate_available

+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
@@ -138,6 +140,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         )

         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -429,6 +432,11 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -814,16 +822,17 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 callback(i, t, latents)

         # 9. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents
+
+        image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()

-        # 10. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
-
         if not return_dict:
             return (image,)

...
@@ -363,6 +363,11 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)

...
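Every decode_latents override gains the same FutureWarning, so downstream code that still calls the helper keeps working. Tests can promote the warning to an error to flag stragglers; a small sketch using only the standard warnings module (pipe and latents assumed from context):

import warnings

with warnings.catch_warnings():
    # Turn the deprecation into a hard failure so CI surfaces remaining callers.
    warnings.simplefilter("error", FutureWarning)
    image = pipe.decode_latents(latents)  # now raises FutureWarning as an exception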
@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import warnings
 from typing import Callable, List, Optional, Tuple, Union

 import numpy as np
@@ -26,6 +27,7 @@ from transformers import (
     CLIPVisionModelWithProjection,
 )

+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, logging, randn_tensor
@@ -88,6 +90,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

         if self.text_unet is not None and (
             "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention
@@ -329,6 +332,11 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -572,12 +580,12 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)

-        # 9. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents

-        # 10. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)

         if not return_dict:
             return (image,)

...
@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import warnings
 from typing import Callable, List, Optional, Union

 import numpy as np
@@ -21,6 +22,7 @@ import torch
 import torch.utils.checkpoint
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, logging, randn_tensor
@@ -71,6 +73,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
@@ -189,6 +192,11 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -414,12 +422,12 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)

-        # 8. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents

-        # 9. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)

         if not return_dict:
             return (image,)

...
@@ -13,12 +13,14 @@
 # limitations under the License.

 import inspect
+import warnings
 from typing import Callable, List, Optional, Union

 import torch
 import torch.utils.checkpoint
 from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer

+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, logging, randn_tensor
@@ -76,6 +78,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

         if self.text_unet is not None:
             self._swap_unet_attention_blocks()
@@ -246,6 +249,11 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -488,12 +496,12 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)

-        # 9. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents

-        # 10. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)

         if not return_dict:
             return (image,)

...
@@ -28,17 +28,18 @@ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
 from diffusers.utils import slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu

-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


 torch.backends.cuda.matmul.allow_tf32 = False


-class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = AltDiffusionPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

     def get_dummy_components(self):
         torch.manual_seed(0)

...
@@ -123,6 +123,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
         tokenizer.model_max_length = 77

         init_image = self.dummy_image.to(device)
+        init_image = init_image / 2 + 0.5

         # make sure here that pndm scheduler skips prk
         alt_pipe = AltDiffusionImg2ImgPipeline(
@@ -134,7 +135,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
             safety_checker=None,
             feature_extractor=self.dummy_extractor,
         )
-        alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
+        alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=True)
         alt_pipe = alt_pipe.to(device)
         alt_pipe.set_progress_bar_config(disable=None)

...
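The two AltDiffusion img2img test changes go together: the dummy image is produced in roughly [-1, 1], so shifting it into [0, 1] lets the processor's normalization be exercised (do_normalize=True) rather than bypassed. A sketch for intuition, mirroring the processor's normalize convention:

# dummy_image lives in roughly [-1, 1]; shift it into [0, 1] first ...
init_image = init_image / 2 + 0.5

# ... so that preprocessing with do_normalize=True maps it back to [-1, 1]:
# normalized = 2.0 * image - 1.0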
@@ -38,6 +38,7 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = PaintByExamplePipeline
     params = IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
     batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
+    image_params = frozenset([])  # TO-DO: update image_params once the pipeline is refactored with VaeImageProcessor.preprocess

     def get_dummy_components(self):
         torch.manual_seed(0)

...
@@ -26,13 +26,13 @@ from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


 torch.backends.cuda.matmul.allow_tf32 = False


-class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = CycleDiffusionPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {
         "negative_prompt",
@@ -42,6 +42,7 @@ class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     }
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"})
+    image_params = frozenset([])  # TO-DO: add image_params once the pipeline is refactored with VaeImageProcessor.preprocess

     def get_dummy_components(self):
         torch.manual_seed(0)

...
@@ -42,17 +42,18 @@ from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu

 from ...models.test_models_unet_2d_condition import create_lora_layers
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


 torch.backends.cuda.matmul.allow_tf32 = False


-class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

     def get_dummy_components(self):
         torch.manual_seed(0)

...
@@ -35,13 +35,14 @@ from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import require_torch_gpu

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


-class StableDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionControlNetPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = frozenset([])  # TO-DO: add image_params once the pipeline is refactored with VaeImageProcessor.preprocess

     def get_dummy_components(self):
         torch.manual_seed(0)

...
@@ -33,16 +33,21 @@ from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow
 from diffusers.utils.testing_utils import require_torch_gpu

 from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


 torch.backends.cuda.matmul.allow_tf32 = False


-class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionImageVariationPipelineFastTests(
+    PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
     pipeline_class = StableDiffusionImageVariationPipeline
     params = IMAGE_VARIATION_PARAMS
     batch_params = IMAGE_VARIATION_BATCH_PARAMS
+    image_params = frozenset(
+        []
+    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess

     def get_dummy_components(self):
         torch.manual_seed(0)

...
@@ -36,16 +36,19 @@ from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow
 from diffusers.utils.testing_utils import require_torch_gpu

 from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


 torch.backends.cuda.matmul.allow_tf32 = False


-class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionInpaintPipeline
     params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
+    image_params = frozenset(
+        []
+    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess

     def get_dummy_components(self):
         torch.manual_seed(0)

...
@@ -35,16 +35,21 @@ from diffusers.utils import floats_tensor, load_image, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu

 from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


 torch.backends.cuda.matmul.allow_tf32 = False


-class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionInstructPix2PixPipelineFastTests(
+    PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
     pipeline_class = StableDiffusionInstructPix2PixPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"}
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
+    image_params = frozenset(
+        []
+    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess

     def get_dummy_components(self):
         torch.manual_seed(0)

...
@@ -31,18 +31,19 @@ from diffusers import (
 from diffusers.utils import slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


 torch.backends.cuda.matmul.allow_tf32 = False


 @skip_mps
-class StableDiffusionModelEditingPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionModelEditingPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionModelEditingPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

     def get_dummy_components(self):
         torch.manual_seed(0)

...
@@ -32,18 +32,19 @@ from diffusers import (
 from diffusers.utils import slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


 torch.backends.cuda.matmul.allow_tf32 = False


 @skip_mps
-class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionPanoramaPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

     def get_dummy_components(self):
         torch.manual_seed(0)

...
@@ -36,17 +36,20 @@ from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps

 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


 torch.backends.cuda.matmul.allow_tf32 = False


 @skip_mps
-class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionPix2PixZeroPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+    image_params = frozenset(
+        []
+    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess

     @classmethod
     def setUpClass(cls):

...
@@ -29,17 +29,18 @@ from diffusers import (
 from diffusers.utils import slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu

-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


 torch.backends.cuda.matmul.allow_tf32 = False


-class StableDiffusionSAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionSAGPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     test_cpu_offload = False

     def get_dummy_components(self):

...
@@ -35,17 +35,18 @@ from diffusers import (
 from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu

-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


 torch.backends.cuda.matmul.allow_tf32 = False


-class StableDiffusion2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusion2PipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

     def get_dummy_components(self):
         torch.manual_seed(0)

...
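The recurring test change is identical everywhere: each fast-test class mixes in PipelineLatentTesterMixin and declares image_params, the set of image-like inputs/outputs the mixin should exercise across output types. The mixin's actual assertions live in ..test_pipelines_common; conceptually, the consistency check looks like this (a sketch, not the mixin's real code):

import numpy as np

def assert_output_types_agree(pipe, get_inputs):
    # get_inputs() must return fresh inputs with a re-seeded generator each call,
    # so both runs are deterministic and comparable.
    out_np = pipe(**get_inputs(), output_type="np").images[0]
    out_pil = pipe(**get_inputs(), output_type="pil").images[0]
    # PIL output round-trips through uint8, so allow a loose tolerance.
    diff = np.abs(out_np - np.asarray(out_pil, dtype=np.float32) / 255.0).max()
    assert diff < 1e-2, f"np and pil outputs diverge: {diff}"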