Unverified Commit 857c04cf authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Pix2Pix0] Add utility function to get edit vector (#2383)

uP
parent 2e7a2865
...@@ -542,6 +542,20 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): ...@@ -542,6 +542,20 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
"""Constructs the edit direction to steer the image generation process semantically.""" """Constructs the edit direction to steer the image generation process semantically."""
return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)
@torch.no_grad()
def get_embeds(self, prompt: List[str]) -> torch.FloatTensor:
input_ids = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).input_ids
embeds = self.text_encoder(input_ids)[0]
return embeds.mean(0)[None]
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment