Unverified Commit a937e1b5 authored by Pi Esposito, committed by GitHub

add load textual inversion embeddings to stable diffusion (#2009)



* add load textual inversion embeddings draft

* fix quality

* fix typo

* make fix copies

* move to textual inversion mixin

* make it accept from sd-concept library

* accept list of paths to embeddings

* fix styling of stable diffusion pipeline

* add dummy TextualInversionMixin

* add docstring to textualinversionmixin

* add case for parsing embedding from auto1111 UI format
Co-authored-by: Evan Jones <evan.a.jones3@gmail.com>
Co-authored-by: Ana Tamais <aninhamoraestamais@gmail.com>

* fix style after rebase

* move textual inversion mixin to loaders

* move mixin inheritance from StableDiffusionPipeline to DiffusionPipeline

* update dummy class name

* addressed all comments

* fix old dangling import

* fix style

* proposal

* remove bogus

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>

* finish

* make style

* up

* fix code quality

* fix code quality - again

* fix code quality - 3

* fix alt diffusion code quality

* fix model editing pipeline

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Finish

---------
Co-authored-by: Evan Jones <evan.a.jones3@gmail.com>
Co-authored-by: Ana Tamais <aninhamoraestamais@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 1d033a95
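
In practice, the feature added by this commit is used as follows. This is a minimal sketch based on the slow test added below; the model and concept names are the ones the test suite uses.

import torch
from diffusers import StableDiffusionPipeline

# StableDiffusionPipeline (and the other pipelines touched in this commit)
# now inherit TextualInversionLoaderMixin.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Load a concept from the sd-concepts-library on the Hub; this registers the
# placeholder token with the tokenizer and its embedding with the text encoder.
pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")
pipe = pipe.to("cuda")

# Use the concept's placeholder token directly in the prompt.
prompt = "a logo of a turtle with <low-poly-hd-logos-icons>"
image = pipe(prompt, generator=torch.Generator(device="cpu").manual_seed(1)).images[0]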
@@ -19,6 +19,7 @@ import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
@@ -47,7 +48,7 @@ EXAMPLE_DOC_STRING = """
 """

-class StableUnCLIPPipeline(DiffusionPipeline):
+class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     """
     Pipeline for text-to-image generation using stable unCLIP.
@@ -367,6 +368,10 @@ class StableUnCLIPPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -427,6 +432,10 @@ class StableUnCLIPPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
...
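The `maybe_convert_prompt` call added above is what makes multi-vector embeddings work: when a concept is loaded with more than one vector, the loader registers extra tokens suffixed `_1`, `_2`, ..., and the prompt is expanded before tokenization. A sketch of the observable behavior, matching the download tests at the end of this commit (`<token>` is a placeholder name):

# After loading a 3-vector embedding under the placeholder "<token>", the
# tokenizer also contains "<token>_1" and "<token>_2".
prompt = pipe.maybe_convert_prompt("hey <token>", pipe.tokenizer)
assert prompt == "hey <token><token>_1<token>_2"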
@@ -21,6 +21,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from diffusers.utils.import_utils import is_accelerate_available

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
@@ -60,7 +61,7 @@ EXAMPLE_DOC_STRING = """
 """

-class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
+class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     """
     Pipeline for text-guided image to image generation using stable unCLIP.
@@ -267,6 +268,10 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -327,6 +332,10 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
...
@@ -19,6 +19,7 @@ import numpy as np
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet3DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -72,7 +73,7 @@ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -
     return images

-class TextToVideoSDPipeline(DiffusionPipeline):
+class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-video generation.
@@ -256,6 +257,10 @@ class TextToVideoSDPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -316,6 +321,10 @@ class TextToVideoSDPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
...
@@ -37,6 +37,8 @@ from .doc_utils import replace_example_docstring
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .hub_utils import (
     HF_HUB_OFFLINE,
+    _add_variant,
+    _get_model_file,
     extract_commit_hash,
     http_user_agent,
 )
...
@@ -2,6 +2,21 @@
 from ..utils import DummyObject, requires_backends

+class TextualInversionLoaderMixin(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
...
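The dummy class above follows the library's existing optional-backend pattern: when torch or transformers is not installed, the import still succeeds and any use of the class raises an informative error via `requires_backends`. A sketch of the effect (the module path is assumed from the surrounding dummy-objects file, not shown in this diff):

# With torch/transformers missing, this name resolves to the dummy class above:
from diffusers.utils.dummy_torch_and_transformers_objects import TextualInversionLoaderMixin

# Instantiation (or from_pretrained/from_config) raises immediately, naming
# the backends that need to be installed.
TextualInversionLoaderMixin()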
@@ -18,16 +18,30 @@ import os
 import re
 import sys
 import traceback
+import warnings
 from pathlib import Path
 from typing import Dict, Optional, Union
 from uuid import uuid4

-from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami
+from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami
 from huggingface_hub.file_download import REGEX_COMMIT_HASH
-from huggingface_hub.utils import is_jinja_available
+from huggingface_hub.utils import (
+    EntryNotFoundError,
+    RepositoryNotFoundError,
+    RevisionNotFoundError,
+    is_jinja_available,
+)
+from packaging import version
+from requests import HTTPError

 from .. import __version__
-from .constants import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT
+from .constants import (
+    DEPRECATED_REVISION_ARGS,
+    DIFFUSERS_CACHE,
+    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+)
 from .import_utils import (
     ENV_VARS_TRUE_VALUES,
     _flax_version,
@@ -215,3 +229,130 @@ if cache_version < 1:
         f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
         "the directory exists and can be written to."
     )
+
+
+def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
+    if variant is not None:
+        splits = weights_name.split(".")
+        splits = splits[:-1] + [variant] + splits[-1:]
+        weights_name = ".".join(splits)
+
+    return weights_name
+
+
+def _get_model_file(
+    pretrained_model_name_or_path,
+    *,
+    weights_name,
+    subfolder,
+    cache_dir,
+    force_download,
+    proxies,
+    resume_download,
+    local_files_only,
+    use_auth_token,
+    user_agent,
+    revision,
+    commit_hash=None,
+):
+    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+    if os.path.isfile(pretrained_model_name_or_path):
+        return pretrained_model_name_or_path
+    elif os.path.isdir(pretrained_model_name_or_path):
+        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
+            # Load from a PyTorch checkpoint
+            model_file = os.path.join(pretrained_model_name_or_path, weights_name)
+            return model_file
+        elif subfolder is not None and os.path.isfile(
+            os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+        ):
+            model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+            return model_file
+        else:
+            raise EnvironmentError(
+                f"Error: no file named {weights_name} found in directory {pretrained_model_name_or_path}."
+            )
+    else:
+        # 1. First check if the deprecated way of loading variants from branches is used
+        if (
+            revision in DEPRECATED_REVISION_ARGS
+            and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
+            and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0")
+        ):
+            try:
+                model_file = hf_hub_download(
+                    pretrained_model_name_or_path,
+                    filename=_add_variant(weights_name, revision),
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                    subfolder=subfolder,
+                    revision=revision or commit_hash,
+                )
+                warnings.warn(
+                    f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
+                    FutureWarning,
+                )
+                return model_file
+            except:  # noqa: E722
+                warnings.warn(
+                    f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}.\nThe Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
+                    FutureWarning,
+                )
+        try:
+            # 2. Load model file as usual
+            model_file = hf_hub_download(
+                pretrained_model_name_or_path,
+                filename=weights_name,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                user_agent=user_agent,
+                subfolder=subfolder,
+                revision=revision or commit_hash,
+            )
+            return model_file
+
+        except RepositoryNotFoundError:
+            raise EnvironmentError(
+                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+                "login`."
+            )
+        except RevisionNotFoundError:
+            raise EnvironmentError(
+                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+                "this model name. Check the model page at "
+                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+            )
+        except EntryNotFoundError:
+            raise EnvironmentError(
+                f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
+            )
+        except HTTPError as err:
+            raise EnvironmentError(
+                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+            )
+        except ValueError:
+            raise EnvironmentError(
+                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it "
+                f"in the cached files, and it looks like {pretrained_model_name_or_path} is not the path to a "
+                f"directory containing a file named {weights_name}.\nCheck your internet connection or see how to "
+                "run the library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+            )
+        except EnvironmentError:
+            raise EnvironmentError(
+                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                f"containing a file named {weights_name}"
+            )
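...

A quick illustration of the helpers added above, derivable directly from their definitions (`diffusion_pytorch_model.bin` is just an example filename). `_add_variant` splices the variant in before the file extension, and `_get_model_file` resolves a weights file by checking, in order: a literal file path, a local directory (optionally with a subfolder), and finally a Hub download.

# The variant is inserted before the extension:
assert _add_variant("diffusion_pytorch_model.bin", "fp16") == "diffusion_pytorch_model.fp16.bin"
# With variant=None the name is returned unchanged:
assert _add_variant("diffusion_pytorch_model.bin") == "diffusion_pytorch_model.bin"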
@@ -21,6 +21,7 @@ import unittest
 import numpy as np
 import torch
+from huggingface_hub import hf_hub_download
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from diffusers import (
@@ -886,6 +887,32 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
         assert mem_bytes_slicing < mem_bytes_offloaded
         assert mem_bytes_slicing < 3 * 10**9

+    def test_stable_diffusion_textual_inversion(self):
+        pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+        pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")
+
+        a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt")
+        a111_file_neg = hf_hub_download(
+            "hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt"
+        )
+        pipe.load_textual_inversion(a111_file)
+        pipe.load_textual_inversion(a111_file_neg)
+        pipe.to("cuda")
+
+        generator = torch.Generator(device="cpu").manual_seed(1)
+        prompt = "An logo of a turtle in strong Style-Winter with <low-poly-hd-logos-icons>"
+        neg_prompt = "Style-Winter-neg"
+
+        image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0]
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy"
+        )
+
+        max_diff = np.abs(expected_image - image).max()
+        assert max_diff < 5e-3
+

 @nightly
 @require_torch_gpu
...
@@ -362,6 +362,97 @@ class DownloadTests(unittest.TestCase):
         diffusers.utils.import_utils._safetensors_available = True

+    def test_text_inversion_download(self):
+        pipe = StableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+        )
+        pipe = pipe.to(torch_device)
+
+        num_tokens = len(pipe.tokenizer)
+
+        # single token load local
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ten = {"<*>": torch.ones((32,))}
+            torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))
+
+            pipe.load_textual_inversion(tmpdirname)
+            token = pipe.tokenizer.convert_tokens_to_ids("<*>")
+            assert token == num_tokens, "Added token must be at spot `num_tokens`"
+            assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32
+            assert pipe._maybe_convert_prompt("<*>", pipe.tokenizer) == "<*>"
+
+            prompt = "hey <*>"
+            out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+            assert out.shape == (1, 128, 128, 3)
+
+        # single token load local with weight name
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ten = {"<**>": 2 * torch.ones((1, 32))}
+            torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))
+
+            pipe.load_textual_inversion(tmpdirname, weight_name="learned_embeds.bin")
+            token = pipe.tokenizer.convert_tokens_to_ids("<**>")
+            assert token == num_tokens + 1, "Added token must be at spot `num_tokens + 1`"
+            assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
+            assert pipe._maybe_convert_prompt("<**>", pipe.tokenizer) == "<**>"
+
+            prompt = "hey <**>"
+            out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+            assert out.shape == (1, 128, 128, 3)
+
+        # multi token load
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ten = {"<***>": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])}
+            torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))
+
+            pipe.load_textual_inversion(tmpdirname)
+            token = pipe.tokenizer.convert_tokens_to_ids("<***>")
+            token_1 = pipe.tokenizer.convert_tokens_to_ids("<***>_1")
+            token_2 = pipe.tokenizer.convert_tokens_to_ids("<***>_2")
+
+            assert token == num_tokens + 2, "Added token must be at spot `num_tokens + 2`"
+            assert token_1 == num_tokens + 3, "Added token must be at spot `num_tokens + 3`"
+            assert token_2 == num_tokens + 4, "Added token must be at spot `num_tokens + 4`"
+            assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
+            assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
+            assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
+            assert pipe._maybe_convert_prompt("<***>", pipe.tokenizer) == "<***><***>_1<***>_2"
+
+            prompt = "hey <***>"
+            out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+            assert out.shape == (1, 128, 128, 3)
+
+        # multi token load a1111
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ten = {
+                "string_to_param": {
+                    "*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])
+                },
+                "name": "<****>",
+            }
+            torch.save(ten, os.path.join(tmpdirname, "a1111.bin"))
+
+            pipe.load_textual_inversion(tmpdirname, weight_name="a1111.bin")
+            token = pipe.tokenizer.convert_tokens_to_ids("<****>")
+            token_1 = pipe.tokenizer.convert_tokens_to_ids("<****>_1")
+            token_2 = pipe.tokenizer.convert_tokens_to_ids("<****>_2")
+
+            assert token == num_tokens + 5, "Added token must be at spot `num_tokens + 5`"
+            assert token_1 == num_tokens + 6, "Added token must be at spot `num_tokens + 6`"
+            assert token_2 == num_tokens + 7, "Added token must be at spot `num_tokens + 7`"
+            assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
+            assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
+            assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
+            assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****><****>_1<****>_2"
+
+            prompt = "hey <****>"
+            out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+            assert out.shape == (1, 128, 128, 3)
+
+
 class CustomPipelineTests(unittest.TestCase):
     def test_load_custom_pipeline(self):
...
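For reference, these are the two on-disk embedding layouts the download test above exercises, written out as standalone snippets (the shapes and token names are taken from the test bodies): the diffusers format maps the placeholder token straight to its tensor, while the A1111/webui format stores the tensor under `string_to_param` and the token under `name`.

import torch

# diffusers-style learned_embeds.bin: {placeholder_token: embedding}
torch.save({"<*>": torch.ones((32,))}, "learned_embeds.bin")

# A1111/webui-style checkpoint: tensor(s) under "string_to_param",
# placeholder token stored separately under "name"
torch.save(
    {"string_to_param": {"*": torch.ones((3, 32))}, "name": "<****>"},
    "a1111.bin",
)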