Unverified Commit 92f15f5b authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Model CPU offload fix for BLIPDiffusion (#5174)

cpu offload fix for blip diffusion
parent 22b19d57
......@@ -98,6 +98,8 @@ class BlipDiffusionPipeline(DiffusionPipeline):
Position of the context token in the text encoder.
"""
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
def __init__(
self,
tokenizer: CLIPTokenizer,
......@@ -155,7 +157,9 @@ class BlipDiffusionPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_prompt(self, query_embeds, prompt):
def encode_prompt(self, query_embeds, prompt, device=None):
device = device or self._execution_device
# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
......@@ -166,7 +170,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(self.device)
).to(device)
batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
......@@ -249,11 +253,12 @@ class BlipDiffusionPipeline(DiffusionPipeline):
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
device = self._execution_device
reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(self.device)
reference_image = reference_image.to(device)
if isinstance(prompt, str):
prompt = [prompt]
......@@ -271,7 +276,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt)
text_embeddings = self.encode_prompt(query_embeds, prompt, device)
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
......@@ -283,7 +288,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
input_ids=uncond_input.input_ids.to(device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
......@@ -300,7 +305,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=self.device,
device=device,
)
# set timesteps
extra_set_kwargs = {}
......@@ -330,9 +335,13 @@ class BlipDiffusionPipeline(DiffusionPipeline):
t,
latents,
)["prev_sample"]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
......
......@@ -107,6 +107,8 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
Position of the context token in the text encoder.
"""
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
def __init__(
self,
tokenizer: CLIPTokenizer,
......@@ -166,7 +168,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_prompt(self, query_embeds, prompt):
def encode_prompt(self, query_embeds, prompt, device=None):
device = device or self._execution_device
# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
......@@ -177,7 +181,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(self.device)
).to(device)
batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
......@@ -297,11 +301,12 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
device = self._execution_device
reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(self.device)
reference_image = reference_image.to(device)
if isinstance(prompt, str):
prompt = [prompt]
......@@ -319,7 +324,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt)
text_embeddings = self.encode_prompt(query_embeds, prompt, device)
# 3. unconditional embedding
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
......@@ -332,7 +337,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
input_ids=uncond_input.input_ids.to(device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
......@@ -348,7 +353,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=self.device,
device=device,
)
# set timesteps
extra_set_kwargs = {}
......@@ -399,6 +404,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment