Unverified Commit ed00ead3 authored by Jincheng Miao's avatar Jincheng Miao Committed by GitHub
Browse files

[Community Pipelines] add textual inversion support for stable_diffusion_ipex (#5571)

parent f0b2f6ce
...@@ -21,6 +21,7 @@ from packaging import version ...@@ -21,6 +21,7 @@ from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.configuration_utils import FrozenDict from diffusers.configuration_utils import FrozenDict
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
...@@ -61,7 +62,7 @@ EXAMPLE_DOC_STRING = """ ...@@ -61,7 +62,7 @@ EXAMPLE_DOC_STRING = """
""" """
class StableDiffusionIPEXPipeline(DiffusionPipeline): class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r""" r"""
Pipeline for text-to-image generation using Stable Diffusion on IPEX. Pipeline for text-to-image generation using Stable Diffusion on IPEX.
...@@ -454,6 +455,10 @@ class StableDiffusionIPEXPipeline(DiffusionPipeline): ...@@ -454,6 +455,10 @@ class StableDiffusionIPEXPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0] batch_size = prompt_embeds.shape[0]
if prompt_embeds is None: if prompt_embeds is None:
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer( text_inputs = self.tokenizer(
prompt, prompt,
padding="max_length", padding="max_length",
...@@ -514,6 +519,10 @@ class StableDiffusionIPEXPipeline(DiffusionPipeline): ...@@ -514,6 +519,10 @@ class StableDiffusionIPEXPipeline(DiffusionPipeline):
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1] max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment