Unverified Commit 77fc197f authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Speed up test and remove kwargs from call (#1446)

Remove kwargs from call
parent edf22c05
...@@ -445,7 +445,6 @@ class AltDiffusionPipeline(DiffusionPipeline): ...@@ -445,7 +445,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -484,7 +484,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -484,7 +484,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -528,7 +528,6 @@ class CycleDiffusionPipeline(DiffusionPipeline): ...@@ -528,7 +528,6 @@ class CycleDiffusionPipeline(DiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -289,7 +289,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -289,7 +289,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
jit: bool = False, jit: bool = False,
debug: bool = False, debug: bool = False,
neg_prompt_ids: jnp.array = None, neg_prompt_ids: jnp.array = None,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -205,7 +205,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -205,7 +205,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs,
): ):
if isinstance(prompt, str): if isinstance(prompt, str):
batch_size = 1 batch_size = 1
......
...@@ -241,7 +241,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -241,7 +241,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -259,7 +259,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -259,7 +259,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -241,7 +241,6 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -241,7 +241,6 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -444,7 +444,6 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -444,7 +444,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -342,7 +342,6 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -342,7 +342,6 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -493,7 +493,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -493,7 +493,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -566,7 +566,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -566,7 +566,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -492,7 +492,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -492,7 +492,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -546,7 +546,6 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): ...@@ -546,7 +546,6 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
sld_threshold: Optional[float] = 0.01, sld_threshold: Optional[float] = 0.01,
sld_momentum_scale: Optional[float] = 0.3, sld_momentum_scale: Optional[float] = 0.3,
sld_mom_beta: Optional[float] = 0.4, sld_mom_beta: Optional[float] = 0.4,
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
......
...@@ -765,18 +765,18 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -765,18 +765,18 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
prompt = "hey" prompt = "hey"
output = sd_pipe(prompt, number_of_steps=1, output_type="np") output = sd_pipe(prompt, num_inference_steps=1, output_type="np")
image_shape = output.images[0].shape[:2] image_shape = output.images[0].shape[:2]
assert image_shape == (64, 64) assert image_shape == (64, 64)
output = sd_pipe(prompt, number_of_steps=1, height=96, width=96, output_type="np") output = sd_pipe(prompt, num_inference_steps=1, height=96, width=96, output_type="np")
image_shape = output.images[0].shape[:2] image_shape = output.images[0].shape[:2]
assert image_shape == (96, 96) assert image_shape == (96, 96)
config = dict(sd_pipe.unet.config) config = dict(sd_pipe.unet.config)
config["sample_size"] = 96 config["sample_size"] = 96
sd_pipe.unet = UNet2DConditionModel.from_config(config).to(torch_device) sd_pipe.unet = UNet2DConditionModel.from_config(config).to(torch_device)
output = sd_pipe(prompt, number_of_steps=1, output_type="np") output = sd_pipe(prompt, num_inference_steps=1, output_type="np")
image_shape = output.images[0].shape[:2] image_shape = output.images[0].shape[:2]
assert image_shape == (192, 192) assert image_shape == (192, 192)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment