Unverified Commit 2e0d489a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Pix2Pix] Add utility function (#2385)

* [Pix2Pix] Add utility function

* improve

* update

* Apply suggestions from code review

* uP

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
parent abd5dcbb
...@@ -542,6 +542,26 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): ...@@ -542,6 +542,26 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
"""Constructs the edit direction to steer the image generation process semantically.""" """Constructs the edit direction to steer the image generation process semantically."""
return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)
@torch.no_grad()
def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.FloatTensor:
num_prompts = len(prompt)
embeds = []
for i in range(0, num_prompts, batch_size):
prompt_slice = prompt[i : i + batch_size]
input_ids = self.tokenizer(
prompt_slice,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).input_ids
input_ids = input_ids.to(self.text_encoder.device)
embeds.append(self.text_encoder(input_ids)[0])
return torch.cat(embeds, dim=0).mean(0)[None]
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment