Unverified commit 2558977b, authored by James R T and committed by GitHub
Browse files

Add callback parameters for Stable Diffusion pipelines (#521)



* Add callback parameters for Stable Diffusion pipelines
Signed-off-by: James R T <jamestiotio@gmail.com>

* Lint code with `black --preview`
Signed-off-by: James R T <jamestiotio@gmail.com>

* Refactor callback implementation for Stable Diffusion pipelines

* Fix missing imports
Signed-off-by: James R T <jamestiotio@gmail.com>

* Fix documentation format
Signed-off-by: James R T <jamestiotio@gmail.com>

* Add kwargs parameter to standardize with other pipelines
Signed-off-by: James R T <jamestiotio@gmail.com>

* Modify Stable Diffusion pipeline callback parameters
Signed-off-by: James R T <jamestiotio@gmail.com>

* Remove useless imports
Signed-off-by: James R T <jamestiotio@gmail.com>

* Change types for timestep and onnx latents

* Fix docstring style

* Return decode_latents and run_safety_checker back into __call__

* Remove unused imports

* Add intermediate state tests for Stable Diffusion pipelines
Signed-off-by: James R T <jamestiotio@gmail.com>

* Fix intermediate state tests for Stable Diffusion pipelines
Signed-off-by: James R T <jamestiotio@gmail.com>
Signed-off-by: James R T <jamestiotio@gmail.com>
parent 5156acc4
import inspect import inspect
import warnings import warnings
from typing import List, Optional, Union from typing import Callable, List, Optional, Union
import torch import torch
...@@ -122,6 +122,8 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -122,6 +122,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs, **kwargs,
): ):
r""" r"""
...@@ -159,6 +161,12 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -159,6 +161,12 @@ class StableDiffusionPipeline(DiffusionPipeline):
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns: Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
...@@ -178,6 +186,14 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -178,6 +186,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# get prompt text embeddings # get prompt text embeddings
text_inputs = self.tokenizer( text_inputs = self.tokenizer(
prompt, prompt,
...@@ -277,14 +293,16 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -277,14 +293,16 @@ class StableDiffusionPipeline(DiffusionPipeline):
else: else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae # call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker( image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
......
import inspect import inspect
import warnings import warnings
from typing import List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -133,6 +133,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -133,6 +133,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -170,6 +173,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -170,6 +173,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns: Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
...@@ -188,6 +197,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -188,6 +197,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# set timesteps # set timesteps
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
...@@ -265,6 +282,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -265,6 +282,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
latents = init_latents latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0) t_start = max(num_inference_steps - init_timestep + offset, 0)
# Some schedulers like PNDM have timesteps as arrays # Some schedulers like PNDM have timesteps as arrays
# It's more optimzed to move all timesteps to correct device beforehand # It's more optimzed to move all timesteps to correct device beforehand
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
...@@ -295,14 +313,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -295,14 +313,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
else: else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae # call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
......
import inspect import inspect
import warnings import warnings
from typing import List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -149,6 +149,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -149,6 +149,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -190,6 +193,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -190,6 +193,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns: Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
...@@ -208,6 +217,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -208,6 +217,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# set timesteps # set timesteps
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
...@@ -297,7 +314,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -297,7 +314,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
extra_step_kwargs["eta"] = eta extra_step_kwargs["eta"] = eta
latents = init_latents latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0) t_start = max(num_inference_steps - init_timestep + offset, 0)
# Some schedulers like PNDM have timesteps as arrays # Some schedulers like PNDM have timesteps as arrays
# It's more optimzed to move all timesteps to correct device beforehand # It's more optimzed to move all timesteps to correct device beforehand
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
...@@ -331,14 +350,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -331,14 +350,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
latents = (init_latents_proper * mask) + (latents * (1 - mask)) latents = (init_latents_proper * mask) + (latents * (1 - mask))
# scale and decode the image latents with vae # call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
......
import inspect import inspect
from typing import List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
...@@ -56,6 +56,8 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline): ...@@ -56,6 +56,8 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
latents: Optional[np.ndarray] = None, latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs, **kwargs,
): ):
if isinstance(prompt, str): if isinstance(prompt, str):
...@@ -68,6 +70,14 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline): ...@@ -68,6 +70,14 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# get prompt text embeddings # get prompt text embeddings
text_inputs = self.tokenizer( text_inputs = self.tokenizer(
prompt, prompt,
...@@ -151,14 +161,18 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline): ...@@ -151,14 +161,18 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
else: else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae latents = np.array(latents)
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
image = self.vae_decoder(latent_sample=latents)[0] image = self.vae_decoder(latent_sample=latents)[0]
image = np.clip(image / 2 + 0.5, 0, 1) image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1)) image = image.transpose((0, 2, 3, 1))
# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image) image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
......
...@@ -1435,3 +1435,177 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -1435,3 +1435,177 @@ class PipelineTesterMixin(unittest.TestCase):
assert image.shape == (1, 512, 512, 3) assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.0385, 0.0252, 0.0234, 0.0287, 0.0358, 0.0287, 0.0276, 0.0235, 0.0010]) expected_slice = np.array([0.0385, 0.0252, 0.0234, 0.0287, 0.0358, 0.0287, 0.0276, 0.0235, 0.0010])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_text2img_intermediate_state(self):
    """Check the per-step callback of the text-to-image pipeline.

    The pipeline is run with ``callback_steps=1`` and the test verifies that
    the callback was invoked, that it was invoked 51 times in total, and that
    the latents handed over at step 0 match a stored reference slice.
    """
    callback_calls = 0

    def on_step(step: int, timestep: int, latents: torch.FloatTensor) -> None:
        nonlocal callback_calls
        callback_calls += 1
        on_step.has_been_called = True
        if step != 0:
            return
        # Only the latents reported for the very first step are compared.
        first_latents = latents.detach().cpu().numpy()
        assert first_latents.shape == (1, 4, 64, 64)
        reference = np.array(
            [1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506]
        )
        observed = first_latents[0, -3:, -3:, -1].flatten()
        assert np.abs(observed - reference).max() < 1e-3

    on_step.has_been_called = False

    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        use_auth_token=True,
        revision="fp16",
        torch_dtype=torch.float16,
    )
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    pipe.enable_attention_slicing()

    # Fixed seed so the step-0 latents are reproducible.
    generator = torch.Generator(device=torch_device).manual_seed(0)
    call_kwargs = {
        "prompt": "Andromeda galaxy in a bottle",
        "num_inference_steps": 50,
        "guidance_scale": 7.5,
        "generator": generator,
        "callback": on_step,
        "callback_steps": 1,
    }
    with torch.autocast(torch_device):
        pipe(**call_kwargs)

    assert on_step.has_been_called
    # NOTE(review): 51 calls for 50 requested steps — presumably the scheduler
    # yields one extra timestep; confirm against the scheduler in use.
    assert callback_calls == 51
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_img2img_intermediate_state(self):
    """Check the per-step callback of the image-to-image pipeline.

    The pipeline is run with ``callback_steps=1`` on a resized sketch image.
    The test verifies that the callback was invoked, that it was invoked 38
    times in total, and that the latents handed over at step 0 match a stored
    reference slice.
    """
    callback_calls = 0

    def on_step(step: int, timestep: int, latents: torch.FloatTensor) -> None:
        nonlocal callback_calls
        callback_calls += 1
        on_step.has_been_called = True
        if step != 0:
            return
        # Only the latents reported for the very first step are compared.
        first_latents = latents.detach().cpu().numpy()
        assert first_latents.shape == (1, 4, 64, 96)
        reference = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530])
        observed = first_latents[0, -3:, -3:, -1].flatten()
        assert np.abs(observed - reference).max() < 1e-3

    on_step.has_been_called = False

    source_image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        "/img2img/sketch-mountains-input.jpg"
    )
    init_image = source_image.resize((768, 512))

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        use_auth_token=True,
        revision="fp16",
        torch_dtype=torch.float16,
    )
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    pipe.enable_attention_slicing()

    # Fixed seed so the step-0 latents are reproducible.
    generator = torch.Generator(device=torch_device).manual_seed(0)
    call_kwargs = {
        "prompt": "A fantasy landscape, trending on artstation",
        "init_image": init_image,
        "strength": 0.75,
        "num_inference_steps": 50,
        "guidance_scale": 7.5,
        "generator": generator,
        "callback": on_step,
        "callback_steps": 1,
    }
    with torch.autocast(torch_device):
        pipe(**call_kwargs)

    assert on_step.has_been_called
    # NOTE(review): 38 calls with strength=0.75 and 50 steps — presumably only
    # the tail of the schedule is run; confirm against the pipeline.
    assert callback_calls == 38
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_inpaint_intermediate_state(self):
    """Check the per-step callback of the inpainting pipeline.

    The pipeline is run with ``callback_steps=1`` on an image/mask pair. The
    test verifies that the callback was invoked, that it was invoked 38 times
    in total, and that the latents handed over at step 0 match a stored
    reference slice.
    """
    callback_calls = 0

    def on_step(step: int, timestep: int, latents: torch.FloatTensor) -> None:
        nonlocal callback_calls
        callback_calls += 1
        on_step.has_been_called = True
        if step != 0:
            return
        # Only the latents reported for the very first step are compared.
        first_latents = latents.detach().cpu().numpy()
        assert first_latents.shape == (1, 4, 64, 64)
        reference = np.array(
            [-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246]
        )
        observed = first_latents[0, -3:, -3:, -1].flatten()
        assert np.abs(observed - reference).max() < 1e-3

    on_step.has_been_called = False

    init_image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        "/in_paint/overture-creations-5sI6fQgYIuo.png"
    )
    mask_image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
    )

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        use_auth_token=True,
        revision="fp16",
        torch_dtype=torch.float16,
    )
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    pipe.enable_attention_slicing()

    # Fixed seed so the step-0 latents are reproducible.
    generator = torch.Generator(device=torch_device).manual_seed(0)
    call_kwargs = {
        "prompt": "A red cat sitting on a park bench",
        "init_image": init_image,
        "mask_image": mask_image,
        "strength": 0.75,
        "num_inference_steps": 50,
        "guidance_scale": 7.5,
        "generator": generator,
        "callback": on_step,
        "callback_steps": 1,
    }
    with torch.autocast(torch_device):
        pipe(**call_kwargs)

    assert on_step.has_been_called
    # NOTE(review): 38 calls with strength=0.75 and 50 steps — presumably only
    # the tail of the schedule is run; confirm against the pipeline.
    assert callback_calls == 38
@slow
def test_stable_diffusion_onnx_intermediate_state(self):
    """Check the per-step callback of the ONNX text-to-image pipeline.

    The pipeline is run with ``callback_steps=1`` after seeding NumPy's global
    RNG. The test verifies that the callback was invoked, that it was invoked
    51 times in total, and that the latents array handed over at step 0
    matches a stored reference slice.
    """
    callback_calls = 0

    def on_step(step: int, timestep: int, latents: np.ndarray) -> None:
        nonlocal callback_calls
        callback_calls += 1
        on_step.has_been_called = True
        if step != 0:
            return
        # Only the latents reported for the very first step are compared.
        assert latents.shape == (1, 4, 64, 64)
        reference = np.array(
            [-0.6254, -0.2742, -1.0710, 0.2296, -1.1683, 0.6913, -2.0605, -0.0682, 0.9700]
        )
        observed = latents[0, -3:, -3:, -1].flatten()
        assert np.abs(observed - reference).max() < 1e-3

    on_step.has_been_called = False

    pipe = StableDiffusionOnnxPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        use_auth_token=True,
        revision="onnx",
        provider="CPUExecutionProvider",
    )
    pipe.set_progress_bar_config(disable=None)

    # Seed the global NumPy RNG before the run, as the reference slice requires.
    np.random.seed(0)
    call_kwargs = {
        "prompt": "Andromeda galaxy in a bottle",
        "num_inference_steps": 50,
        "guidance_scale": 7.5,
        "callback": on_step,
        "callback_steps": 1,
    }
    pipe(**call_kwargs)

    assert on_step.has_been_called
    # NOTE(review): 51 calls for 50 requested steps — presumably the scheduler
    # yields one extra timestep; confirm against the scheduler in use.
    assert callback_calls == 51
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment