Unverified commit 92f15f5b authored by Dhruv Nair, committed by GitHub

Model CPU offload fix for BLIPDiffusion (#5174)

CPU offload fix for BLIP-Diffusion.
parent 22b19d57
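
This change adds a `model_cpu_offload_seq` to both BLIP-Diffusion pipelines and routes input tensors through the hook-aware `_execution_device`, so `enable_model_cpu_offload()` can be used with them. Below is a minimal usage sketch, assuming the `Salesforce/blipdiffusion` checkpoint, an available CUDA device, and the positional argument order from the pipeline docstring (prompt, reference image, source subject, target subject); the image path and subject strings are placeholders.

```python
import torch
from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.utils import load_image

pipe = BlipDiffusionPipeline.from_pretrained(
    "Salesforce/blipdiffusion", torch_dtype=torch.float16
)
# Instead of pipe.to("cuda"): keep modules on CPU and let accelerate hooks move
# each one (qformer -> text_encoder -> unet -> vae) to the GPU only when it runs.
pipe.enable_model_cpu_offload()

reference_image = load_image("path/or/url/to/dog.jpg")  # placeholder reference image

output = pipe(
    "swimming underwater",   # prompt
    reference_image,         # reference image of the source subject
    "dog",                   # source subject category
    "dog",                   # target subject category
    guidance_scale=7.5,
    num_inference_steps=25,
    height=512,
    width=512,
).images
output[0].save("blip_diffusion_offload.png")
```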
@@ -98,6 +98,8 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """
 
+    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
@@ -155,7 +157,9 @@ class BlipDiffusionPipeline(DiffusionPipeline):
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def encode_prompt(self, query_embeds, prompt):
+    def encode_prompt(self, query_embeds, prompt, device=None):
+        device = device or self._execution_device
+
         # embeddings for prompt, with query_embeds as context
         max_len = self.text_encoder.text_model.config.max_position_embeddings
         max_len -= self.qformer.config.num_query_tokens
@@ -166,7 +170,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             truncation=True,
             max_length=max_len,
             return_tensors="pt",
-        ).to(self.device)
+        ).to(device)
 
         batch_size = query_embeds.shape[0]
         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -249,11 +253,12 @@ class BlipDiffusionPipeline(DiffusionPipeline):
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`
         """
+        device = self._execution_device
         reference_image = self.image_processor.preprocess(
             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
         )["pixel_values"]
-        reference_image = reference_image.to(self.device)
+        reference_image = reference_image.to(device)
 
         if isinstance(prompt, str):
             prompt = [prompt]
@@ -271,7 +276,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             prompt_reps=prompt_reps,
         )
         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
-        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
         do_classifier_free_guidance = guidance_scale > 1.0
         if do_classifier_free_guidance:
             max_length = self.text_encoder.text_model.config.max_position_embeddings
@@ -283,7 +288,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
                 return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(
-                input_ids=uncond_input.input_ids.to(self.device),
+                input_ids=uncond_input.input_ids.to(device),
                 ctx_embeddings=None,
             )[0]
             # For classifier free guidance, we need to do two forward passes.
@@ -300,7 +305,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             generator=generator,
             latents=latents,
             dtype=self.unet.dtype,
-            device=self.device,
+            device=device,
         )
         # set timesteps
         extra_set_kwargs = {}
@@ -330,9 +335,13 @@ class BlipDiffusionPipeline(DiffusionPipeline):
                 t,
                 latents,
             )["prev_sample"]
 
         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return (image,)
...
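
The reason for threading `device` through explicitly: under model offload, `self.device` reflects where the pipeline's modules currently sit, which is the CPU between forward passes, while `self._execution_device` reflects the device the accelerate hooks move each module to when it runs, which is where the input tensors need to go. A quick check of the difference, assuming the same checkpoint as above and an available GPU:

```python
import torch
from diffusers.pipelines import BlipDiffusionPipeline

pipe = BlipDiffusionPipeline.from_pretrained(
    "Salesforce/blipdiffusion", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# Modules stay on CPU until their hook runs, so the module-based device
# property still reports "cpu" ...
print(pipe.device)             # cpu
# ... while the hook-aware property points at the offload target, which is
# what encode_prompt() and prepare_latents() need for their tensors.
print(pipe._execution_device)  # cuda:0 (assuming a CUDA device is present)
```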
@@ -107,6 +107,8 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """
 
+    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
@@ -166,7 +168,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def encode_prompt(self, query_embeds, prompt):
+    def encode_prompt(self, query_embeds, prompt, device=None):
+        device = device or self._execution_device
+
         # embeddings for prompt, with query_embeds as context
         max_len = self.text_encoder.text_model.config.max_position_embeddings
         max_len -= self.qformer.config.num_query_tokens
@@ -177,7 +181,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             truncation=True,
             max_length=max_len,
             return_tensors="pt",
-        ).to(self.device)
+        ).to(device)
 
         batch_size = query_embeds.shape[0]
         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -297,11 +301,12 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`
         """
+        device = self._execution_device
         reference_image = self.image_processor.preprocess(
             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
         )["pixel_values"]
-        reference_image = reference_image.to(self.device)
+        reference_image = reference_image.to(device)
 
         if isinstance(prompt, str):
             prompt = [prompt]
@@ -319,7 +324,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             prompt_reps=prompt_reps,
         )
         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
-        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
         # 3. unconditional embedding
         do_classifier_free_guidance = guidance_scale > 1.0
         if do_classifier_free_guidance:
@@ -332,7 +337,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
                 return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(
-                input_ids=uncond_input.input_ids.to(self.device),
+                input_ids=uncond_input.input_ids.to(device),
                 ctx_embeddings=None,
             )[0]
             # For classifier free guidance, we need to do two forward passes.
@@ -348,7 +353,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             generator=generator,
             latents=latents,
             dtype=self.unet.dtype,
-            device=self.device,
+            device=device,
        )
         # set timesteps
         extra_set_kwargs = {}
@@ -399,6 +404,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
 
         if not return_dict:
             return (image,)
...
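
The second file applies the same changes to `BlipDiffusionControlNetPipeline`, so offload can be enabled there in the same way. A sketch assuming the `Salesforce/blipdiffusion-controlnet` checkpoint; the image sources are placeholders, the extra positional argument is the ControlNet conditioning image (e.g. a canny edge map), and the argument order should be double-checked against the installed version.

```python
import torch
from diffusers.pipelines import BlipDiffusionControlNetPipeline
from diffusers.utils import load_image

pipe = BlipDiffusionControlNetPipeline.from_pretrained(
    "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

style_image = load_image("path/or/url/to/style_subject.jpg")  # placeholder reference image
cond_image = load_image("path/or/url/to/canny_edges.png")     # placeholder conditioning image

output = pipe(
    "on a marble table",     # prompt
    style_image,             # reference image of the source subject
    cond_image,              # ControlNet conditioning image
    "flower",                # source subject category
    "teapot",                # target subject category
    guidance_scale=7.5,
    num_inference_steps=50,
    height=512,
    width=512,
).images
output[0].save("blip_diffusion_controlnet_offload.png")
```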