Unverified commit 8ca179a0, authored by Dhruv Nair and committed by GitHub
Browse files

Update free model hooks (#5680)

update free model hooks
parent 71f56c77
...@@ -1109,8 +1109,6 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -1109,8 +1109,6 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
nsfw_detected = None nsfw_detected = None
watermark_detected = None watermark_detected = None
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
self.unet_offload_hook.offload()
else: else:
# 10. Post-processing # 10. Post-processing
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
...@@ -1119,9 +1117,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -1119,9 +1117,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
# 11. Run safety checker # 11. Run safety checker
image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
# Offload last model to CPU self.maybe_free_model_hooks()
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict: if not return_dict:
return (image, nsfw_detected, watermark_detected) return (image, nsfw_detected, watermark_detected)
......
...@@ -388,6 +388,8 @@ class KandinskyPipeline(DiffusionPipeline): ...@@ -388,6 +388,8 @@ class KandinskyPipeline(DiffusionPipeline):
# post-processing # post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"] image = self.movq.decode(latents, force_not_quantize=True)["sample"]
self.maybe_free_model_hooks()
if output_type not in ["pt", "np", "pil"]: if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
......
...@@ -321,6 +321,9 @@ class KandinskyCombinedPipeline(DiffusionPipeline): ...@@ -321,6 +321,9 @@ class KandinskyCombinedPipeline(DiffusionPipeline):
callback_steps=callback_steps, callback_steps=callback_steps,
return_dict=return_dict, return_dict=return_dict,
) )
self.maybe_free_model_hooks()
return outputs return outputs
...@@ -558,6 +561,9 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline): ...@@ -558,6 +561,9 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
callback_steps=callback_steps, callback_steps=callback_steps,
return_dict=return_dict, return_dict=return_dict,
) )
self.maybe_free_model_hooks()
return outputs return outputs
...@@ -593,7 +599,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline): ...@@ -593,7 +599,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
""" """
_load_connected_pipes = True _load_connected_pipes = True
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq" model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
def __init__( def __init__(
self, self,
...@@ -802,4 +808,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline): ...@@ -802,4 +808,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
callback_steps=callback_steps, callback_steps=callback_steps,
return_dict=return_dict, return_dict=return_dict,
) )
self.maybe_free_model_hooks()
return outputs return outputs
...@@ -481,6 +481,8 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline): ...@@ -481,6 +481,8 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
# 7. post-processing # 7. post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"] image = self.movq.decode(latents, force_not_quantize=True)["sample"]
self.maybe_free_model_hooks()
if output_type not in ["pt", "np", "pil"]: if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
......
...@@ -616,6 +616,8 @@ class KandinskyInpaintPipeline(DiffusionPipeline): ...@@ -616,6 +616,8 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
# post-processing # post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"] image = self.movq.decode(latents, force_not_quantize=True)["sample"]
self.maybe_free_model_hooks()
if output_type not in ["pt", "np", "pil"]: if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
......
...@@ -527,7 +527,7 @@ class KandinskyPriorPipeline(DiffusionPipeline): ...@@ -527,7 +527,7 @@ class KandinskyPriorPipeline(DiffusionPipeline):
if negative_prompt is None: if negative_prompt is None:
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
self.maybe_free_model_hooks self.maybe_free_model_hooks()
else: else:
image_embeddings, zero_embeds = image_embeddings.chunk(2) image_embeddings, zero_embeds = image_embeddings.chunk(2)
......
...@@ -326,6 +326,8 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline): ...@@ -326,6 +326,8 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
callback_on_step_end=callback_on_step_end, callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
) )
self.maybe_free_model_hooks()
return outputs return outputs
...@@ -572,6 +574,8 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline): ...@@ -572,6 +574,8 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
callback_on_step_end=callback_on_step_end, callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
) )
self.maybe_free_model_hooks()
return outputs return outputs
...@@ -842,4 +846,6 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline): ...@@ -842,4 +846,6 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
**kwargs, **kwargs,
) )
self.maybe_free_model_hooks()
return outputs return outputs
...@@ -531,14 +531,10 @@ class KandinskyV22PriorPipeline(DiffusionPipeline): ...@@ -531,14 +531,10 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
# if negative prompt has been defined, we retrieve split the image embedding into two # if negative prompt has been defined, we retrieve split the image embedding into two
if negative_prompt is None: if negative_prompt is None:
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
else: else:
image_embeddings, zero_embeds = image_embeddings.chunk(2) image_embeddings, zero_embeds = image_embeddings.chunk(2)
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.maybe_free_model_hooks()
self.prior_hook.offload()
if output_type not in ["pt", "np"]: if output_type not in ["pt", "np"]:
raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
......
...@@ -545,12 +545,10 @@ class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline): ...@@ -545,12 +545,10 @@ class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline):
# if negative prompt has been defined, we retrieve split the image embedding into two # if negative prompt has been defined, we retrieve split the image embedding into two
if negative_prompt is None: if negative_prompt is None:
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
else: else:
image_embeddings, zero_embeds = image_embeddings.chunk(2) image_embeddings, zero_embeds = image_embeddings.chunk(2)
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.prior_hook.offload() self.maybe_free_model_hooks()
if output_type not in ["pt", "np"]: if output_type not in ["pt", "np"]:
raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
......
...@@ -918,6 +918,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor ...@@ -918,6 +918,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
self.maybe_free_model_hooks()
if not return_dict: if not return_dict:
return (image, has_nsfw_concept) return (image, has_nsfw_concept)
......
...@@ -1027,6 +1027,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion ...@@ -1027,6 +1027,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
self.maybe_free_model_hooks()
if not return_dict: if not return_dict:
return (image, has_nsfw_concept) return (image, has_nsfw_concept)
......
...@@ -846,6 +846,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -846,6 +846,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
image = latents image = latents
image = self.image_processor.postprocess(image, output_type=output_type) image = self.image_processor.postprocess(image, output_type=output_type)
self.maybe_free_model_hooks()
if not return_dict: if not return_dict:
return (image,) return (image,)
......
...@@ -439,6 +439,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -439,6 +439,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
self.maybe_free_model_hooks()
if not return_dict: if not return_dict:
return (image, has_nsfw_concept) return (image, has_nsfw_concept)
......
...@@ -511,6 +511,8 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, FromSingleFileMixi ...@@ -511,6 +511,8 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, FromSingleFileMixi
image = self.image_processor.postprocess(image, output_type=output_type) image = self.image_processor.postprocess(image, output_type=output_type)
self.maybe_free_model_hooks()
if not return_dict: if not return_dict:
return (image,) return (image,)
......
...@@ -802,6 +802,8 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -802,6 +802,8 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
self.maybe_free_model_hooks()
if not return_dict: if not return_dict:
return (image, has_nsfw_concept) return (image, has_nsfw_concept)
......
...@@ -741,6 +741,8 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) ...@@ -741,6 +741,8 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
self.maybe_free_model_hooks()
if not return_dict: if not return_dict:
return (image, has_nsfw_concept) return (image, has_nsfw_concept)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment