Unverified Commit 145522cb authored by suzukimain's avatar suzukimain Committed by GitHub
Browse files

[Community] Enhanced `Model Search` (#10417)

* Added `auto_load_textual_inversion` and `auto_load_lora_weights`

* update README.md

* fix

* make quality

* Fix and `make style`
parent 23bc56a0
...@@ -82,31 +82,11 @@ pipeline = EasyPipelineForInpainting.from_huggingface( ...@@ -82,31 +82,11 @@ pipeline = EasyPipelineForInpainting.from_huggingface(
## Search Civitai and Huggingface ## Search Civitai and Huggingface
```python ```python
from pipeline_easy import (
search_huggingface,
search_civitai,
)
# Search Lora
Lora = search_civitai(
"Keyword_to_search_Lora",
model_type="LORA",
base_model = "SD 1.5",
download=True,
)
# Load Lora into the pipeline. # Load Lora into the pipeline.
pipeline.load_lora_weights(Lora) pipeline.auto_load_lora_weights("Detail Tweaker")
# Search TextualInversion
TextualInversion = search_civitai(
"EasyNegative",
model_type="TextualInversion",
base_model = "SD 1.5",
download=True
)
# Load TextualInversion into the pipeline. # Load TextualInversion into the pipeline.
pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") pipeline.auto_load_textual_inversion("EasyNegative", token="EasyNegative")
``` ```
### Search Civitai ### Search Civitai
......
# coding=utf-8 # coding=utf-8
# Copyright 2024 suzukimain # Copyright 2025 suzukimain
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -15,11 +15,13 @@ ...@@ -15,11 +15,13 @@
import os import os
import re import re
import types
from collections import OrderedDict from collections import OrderedDict
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass, field
from typing import Union from typing import Dict, List, Optional, Union
import requests import requests
import torch
from huggingface_hub import hf_api, hf_hub_download from huggingface_hub import hf_api, hf_hub_download
from huggingface_hub.file_download import http_get from huggingface_hub.file_download import http_get
from huggingface_hub.utils import validate_hf_hub_args from huggingface_hub.utils import validate_hf_hub_args
...@@ -30,6 +32,7 @@ from diffusers.loaders.single_file_utils import ( ...@@ -30,6 +32,7 @@ from diffusers.loaders.single_file_utils import (
infer_diffusers_model_type, infer_diffusers_model_type,
load_single_file_checkpoint, load_single_file_checkpoint,
) )
from diffusers.pipelines.animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline
from diffusers.pipelines.auto_pipeline import ( from diffusers.pipelines.auto_pipeline import (
AutoPipelineForImage2Image, AutoPipelineForImage2Image,
AutoPipelineForInpainting, AutoPipelineForInpainting,
...@@ -39,13 +42,18 @@ from diffusers.pipelines.controlnet import ( ...@@ -39,13 +42,18 @@ from diffusers.pipelines.controlnet import (
StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline, StableDiffusionControlNetPipeline,
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetPipeline,
) )
from diffusers.pipelines.flux import FluxImg2ImgPipeline, FluxPipeline
from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import ( from diffusers.pipelines.stable_diffusion import (
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline, StableDiffusionInpaintPipeline,
StableDiffusionPipeline, StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
) )
from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
from diffusers.pipelines.stable_diffusion_xl import ( from diffusers.pipelines.stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline, StableDiffusionXLInpaintPipeline,
...@@ -59,46 +67,133 @@ logger = logging.get_logger(__name__) ...@@ -59,46 +67,133 @@ logger = logging.get_logger(__name__)
SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING = OrderedDict( SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING = OrderedDict(
[ [
("xl_base", StableDiffusionXLPipeline), ("animatediff_rgb", AnimateDiffPipeline),
("xl_refiner", StableDiffusionXLPipeline), ("animatediff_scribble", AnimateDiffPipeline),
("xl_inpaint", None), ("animatediff_sdxl_beta", AnimateDiffSDXLPipeline),
("playground-v2-5", StableDiffusionXLPipeline), ("animatediff_v1", AnimateDiffPipeline),
("upscale", None), ("animatediff_v2", AnimateDiffPipeline),
("animatediff_v3", AnimateDiffPipeline),
("autoencoder-dc-f128c512", None),
("autoencoder-dc-f32c32", None),
("autoencoder-dc-f32c32-sana", None),
("autoencoder-dc-f64c128", None),
("controlnet", StableDiffusionControlNetPipeline),
("controlnet_xl", StableDiffusionXLControlNetPipeline),
("controlnet_xl_large", StableDiffusionXLControlNetPipeline),
("controlnet_xl_mid", StableDiffusionXLControlNetPipeline),
("controlnet_xl_small", StableDiffusionXLControlNetPipeline),
("flux-depth", FluxPipeline),
("flux-dev", FluxPipeline),
("flux-fill", FluxPipeline),
("flux-schnell", FluxPipeline),
("hunyuan-video", None),
("inpainting", None), ("inpainting", None),
("inpainting_v2", None), ("inpainting_v2", None),
("controlnet", StableDiffusionControlNetPipeline), ("ltx-video", None),
("v2", StableDiffusionPipeline), ("ltx-video-0.9.1", None),
("mochi-1-preview", None),
("playground-v2-5", StableDiffusionXLPipeline),
("sd3", StableDiffusion3Pipeline),
("sd35_large", StableDiffusion3Pipeline),
("sd35_medium", StableDiffusion3Pipeline),
("stable_cascade_stage_b", None),
("stable_cascade_stage_b_lite", None),
("stable_cascade_stage_c", None),
("stable_cascade_stage_c_lite", None),
("upscale", StableDiffusionUpscalePipeline),
("v1", StableDiffusionPipeline), ("v1", StableDiffusionPipeline),
("v2", StableDiffusionPipeline),
("xl_base", StableDiffusionXLPipeline),
("xl_inpaint", None),
("xl_refiner", StableDiffusionXLPipeline),
] ]
) )
SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING = OrderedDict( SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING = OrderedDict(
[ [
("xl_base", StableDiffusionXLImg2ImgPipeline), ("animatediff_rgb", AnimateDiffPipeline),
("xl_refiner", StableDiffusionXLImg2ImgPipeline), ("animatediff_scribble", AnimateDiffPipeline),
("xl_inpaint", None), ("animatediff_sdxl_beta", AnimateDiffSDXLPipeline),
("playground-v2-5", StableDiffusionXLImg2ImgPipeline), ("animatediff_v1", AnimateDiffPipeline),
("upscale", None), ("animatediff_v2", AnimateDiffPipeline),
("animatediff_v3", AnimateDiffPipeline),
("autoencoder-dc-f128c512", None),
("autoencoder-dc-f32c32", None),
("autoencoder-dc-f32c32-sana", None),
("autoencoder-dc-f64c128", None),
("controlnet", StableDiffusionControlNetImg2ImgPipeline),
("controlnet_xl", StableDiffusionXLControlNetImg2ImgPipeline),
("controlnet_xl_large", StableDiffusionXLControlNetImg2ImgPipeline),
("controlnet_xl_mid", StableDiffusionXLControlNetImg2ImgPipeline),
("controlnet_xl_small", StableDiffusionXLControlNetImg2ImgPipeline),
("flux-depth", FluxImg2ImgPipeline),
("flux-dev", FluxImg2ImgPipeline),
("flux-fill", FluxImg2ImgPipeline),
("flux-schnell", FluxImg2ImgPipeline),
("hunyuan-video", None),
("inpainting", None), ("inpainting", None),
("inpainting_v2", None), ("inpainting_v2", None),
("controlnet", StableDiffusionControlNetImg2ImgPipeline), ("ltx-video", None),
("v2", StableDiffusionImg2ImgPipeline), ("ltx-video-0.9.1", None),
("mochi-1-preview", None),
("playground-v2-5", StableDiffusionXLImg2ImgPipeline),
("sd3", StableDiffusion3Img2ImgPipeline),
("sd35_large", StableDiffusion3Img2ImgPipeline),
("sd35_medium", StableDiffusion3Img2ImgPipeline),
("stable_cascade_stage_b", None),
("stable_cascade_stage_b_lite", None),
("stable_cascade_stage_c", None),
("stable_cascade_stage_c_lite", None),
("upscale", StableDiffusionUpscalePipeline),
("v1", StableDiffusionImg2ImgPipeline), ("v1", StableDiffusionImg2ImgPipeline),
("v2", StableDiffusionImg2ImgPipeline),
("xl_base", StableDiffusionXLImg2ImgPipeline),
("xl_inpaint", None),
("xl_refiner", StableDiffusionXLImg2ImgPipeline),
] ]
) )
SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING = OrderedDict( SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING = OrderedDict(
[ [
("xl_base", None), ("animatediff_rgb", None),
("xl_refiner", None), ("animatediff_scribble", None),
("xl_inpaint", StableDiffusionXLInpaintPipeline), ("animatediff_sdxl_beta", None),
("playground-v2-5", None), ("animatediff_v1", None),
("upscale", None), ("animatediff_v2", None),
("animatediff_v3", None),
("autoencoder-dc-f128c512", None),
("autoencoder-dc-f32c32", None),
("autoencoder-dc-f32c32-sana", None),
("autoencoder-dc-f64c128", None),
("controlnet", StableDiffusionControlNetInpaintPipeline),
("controlnet_xl", None),
("controlnet_xl_large", None),
("controlnet_xl_mid", None),
("controlnet_xl_small", None),
("flux-depth", None),
("flux-dev", None),
("flux-fill", None),
("flux-schnell", None),
("hunyuan-video", None),
("inpainting", StableDiffusionInpaintPipeline), ("inpainting", StableDiffusionInpaintPipeline),
("inpainting_v2", StableDiffusionInpaintPipeline), ("inpainting_v2", StableDiffusionInpaintPipeline),
("controlnet", StableDiffusionControlNetInpaintPipeline), ("ltx-video", None),
("v2", None), ("ltx-video-0.9.1", None),
("mochi-1-preview", None),
("playground-v2-5", None),
("sd3", None),
("sd35_large", None),
("sd35_medium", None),
("stable_cascade_stage_b", None),
("stable_cascade_stage_b_lite", None),
("stable_cascade_stage_c", None),
("stable_cascade_stage_c_lite", None),
("upscale", StableDiffusionUpscalePipeline),
("v1", None), ("v1", None),
("v2", None),
("xl_base", None),
("xl_inpaint", StableDiffusionXLInpaintPipeline),
("xl_refiner", None),
] ]
) )
...@@ -116,14 +211,33 @@ CONFIG_FILE_LIST = [ ...@@ -116,14 +211,33 @@ CONFIG_FILE_LIST = [
"diffusion_pytorch_model.non_ema.safetensors", "diffusion_pytorch_model.non_ema.safetensors",
] ]
DIFFUSERS_CONFIG_DIR = ["safety_checker", "unet", "vae", "text_encoder", "text_encoder_2"] DIFFUSERS_CONFIG_DIR = [
"safety_checker",
INPAINT_PIPELINE_KEYS = [ "unet",
"xl_inpaint", "vae",
"inpainting", "text_encoder",
"inpainting_v2", "text_encoder_2",
] ]
TOKENIZER_SHAPE_MAP = {
768: [
"SD 1.4",
"SD 1.5",
"SD 1.5 LCM",
"SDXL 0.9",
"SDXL 1.0",
"SDXL 1.0 LCM",
"SDXL Distilled",
"SDXL Turbo",
"SDXL Lightning",
"PixArt a",
"Playground v2",
"Pony",
],
1024: ["SD 2.0", "SD 2.0 768", "SD 2.1", "SD 2.1 768", "SD 2.1 Unclip"],
}
EXTENSION = [".safetensors", ".ckpt", ".bin"] EXTENSION = [".safetensors", ".ckpt", ".bin"]
CACHE_HOME = os.path.expanduser("~/.cache") CACHE_HOME = os.path.expanduser("~/.cache")
...@@ -162,12 +276,28 @@ class ModelStatus: ...@@ -162,12 +276,28 @@ class ModelStatus:
The name of the model file. The name of the model file.
local (`bool`): local (`bool`):
Whether the model exists locally Whether the model exists locally
site_url (`str`):
The URL of the site where the model is hosted.
""" """
search_word: str = "" search_word: str = ""
download_url: str = "" download_url: str = ""
file_name: str = "" file_name: str = ""
local: bool = False local: bool = False
site_url: str = ""
@dataclass
class ExtraStatus:
r"""
Data class for storing extra status information.
Attributes:
trained_words (`str`):
The words used to trigger the model
"""
trained_words: Union[List[str], None] = None
@dataclass @dataclass
...@@ -191,8 +321,9 @@ class SearchResult: ...@@ -191,8 +321,9 @@ class SearchResult:
model_path: str = "" model_path: str = ""
loading_method: Union[str, None] = None loading_method: Union[str, None] = None
checkpoint_format: Union[str, None] = None checkpoint_format: Union[str, None] = None
repo_status: RepoStatus = RepoStatus() repo_status: RepoStatus = field(default_factory=RepoStatus)
model_status: ModelStatus = ModelStatus() model_status: ModelStatus = field(default_factory=ModelStatus)
extra_status: ExtraStatus = field(default_factory=ExtraStatus)
@validate_hf_hub_args @validate_hf_hub_args
...@@ -385,6 +516,7 @@ def file_downloader( ...@@ -385,6 +516,7 @@ def file_downloader(
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
displayed_filename = kwargs.pop("displayed_filename", None) displayed_filename = kwargs.pop("displayed_filename", None)
# Default mode for file writing and initial file size # Default mode for file writing and initial file size
mode = "wb" mode = "wb"
file_size = 0 file_size = 0
...@@ -396,7 +528,7 @@ def file_downloader( ...@@ -396,7 +528,7 @@ def file_downloader(
if os.path.exists(save_path): if os.path.exists(save_path):
if not force_download: if not force_download:
# If the file exists and force_download is False, skip the download # If the file exists and force_download is False, skip the download
logger.warning(f"File already exists: {save_path}, skipping download.") logger.info(f"File already exists: {save_path}, skipping download.")
return None return None
elif resume: elif resume:
# If resuming, set mode to append binary and get current file size # If resuming, set mode to append binary and get current file size
...@@ -457,10 +589,18 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N ...@@ -457,10 +589,18 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
gated = kwargs.pop("gated", False) gated = kwargs.pop("gated", False)
skip_error = kwargs.pop("skip_error", False) skip_error = kwargs.pop("skip_error", False)
file_list = []
hf_repo_info = {}
hf_security_info = {}
model_path = ""
repo_id, file_name = "", ""
diffusers_model_exists = False
# Get the type and loading method for the keyword # Get the type and loading method for the keyword
search_word_status = get_keyword_types(search_word) search_word_status = get_keyword_types(search_word)
if search_word_status["type"]["hf_repo"]: if search_word_status["type"]["hf_repo"]:
hf_repo_info = hf_api.model_info(repo_id=search_word, securityStatus=True)
if download: if download:
model_path = DiffusionPipeline.download( model_path = DiffusionPipeline.download(
search_word, search_word,
...@@ -503,13 +643,6 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N ...@@ -503,13 +643,6 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
) )
model_dicts = [asdict(value) for value in list(hf_models)] model_dicts = [asdict(value) for value in list(hf_models)]
file_list = []
hf_repo_info = {}
hf_security_info = {}
model_path = ""
repo_id, file_name = "", ""
diffusers_model_exists = False
# Loop through models to find a suitable candidate # Loop through models to find a suitable candidate
for repo_info in model_dicts: for repo_info in model_dicts:
repo_id = repo_info["id"] repo_id = repo_info["id"]
...@@ -523,7 +656,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N ...@@ -523,7 +656,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
if hf_security_info["scansDone"]: if hf_security_info["scansDone"]:
for info in repo_info["siblings"]: for info in repo_info["siblings"]:
file_path = info["rfilename"] file_path = info["rfilename"]
if "model_index.json" == file_path and checkpoint_format in ["diffusers", "all"]: if "model_index.json" == file_path and checkpoint_format in [
"diffusers",
"all",
]:
diffusers_model_exists = True diffusers_model_exists = True
break break
...@@ -571,6 +707,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N ...@@ -571,6 +707,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
force_download=force_download, force_download=force_download,
) )
# `pathlib.PosixPath` may be returned
if model_path:
model_path = str(model_path)
if file_name: if file_name:
download_url = f"https://huggingface.co/{repo_id}/blob/main/{file_name}" download_url = f"https://huggingface.co/{repo_id}/blob/main/{file_name}"
else: else:
...@@ -586,10 +726,12 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N ...@@ -586,10 +726,12 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
repo_status=RepoStatus(repo_id=repo_id, repo_hash=hf_repo_info.sha, version=revision), repo_status=RepoStatus(repo_id=repo_id, repo_hash=hf_repo_info.sha, version=revision),
model_status=ModelStatus( model_status=ModelStatus(
search_word=search_word, search_word=search_word,
site_url=download_url,
download_url=download_url, download_url=download_url,
file_name=file_name, file_name=file_name,
local=download, local=download,
), ),
extra_status=ExtraStatus(trained_words=None),
) )
else: else:
...@@ -605,6 +747,8 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] ...@@ -605,6 +747,8 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
The search query string. The search query string.
model_type (`str`, *optional*, defaults to `Checkpoint`): model_type (`str`, *optional*, defaults to `Checkpoint`):
The type of model to search for. The type of model to search for.
sort (`str`, *optional*):
The order in which you wish to sort the results(for example, `Highest Rated`, `Most Downloaded`, `Newest`).
base_model (`str`, *optional*): base_model (`str`, *optional*):
The base model to filter by. The base model to filter by.
download (`bool`, *optional*, defaults to `False`): download (`bool`, *optional*, defaults to `False`):
...@@ -628,6 +772,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] ...@@ -628,6 +772,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
# Extract additional parameters from kwargs # Extract additional parameters from kwargs
model_type = kwargs.pop("model_type", "Checkpoint") model_type = kwargs.pop("model_type", "Checkpoint")
sort = kwargs.pop("sort", None)
download = kwargs.pop("download", False) download = kwargs.pop("download", False)
base_model = kwargs.pop("base_model", None) base_model = kwargs.pop("base_model", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
...@@ -642,6 +787,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] ...@@ -642,6 +787,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
repo_name = "" repo_name = ""
repo_id = "" repo_id = ""
version_id = "" version_id = ""
trainedWords = ""
models_list = [] models_list = []
selected_repo = {} selected_repo = {}
selected_model = {} selected_model = {}
...@@ -652,12 +798,16 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] ...@@ -652,12 +798,16 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
params = { params = {
"query": search_word, "query": search_word,
"types": model_type, "types": model_type,
"sort": "Most Downloaded",
"limit": 20, "limit": 20,
} }
if base_model is not None: if base_model is not None:
if not isinstance(base_model, list):
base_model = [base_model]
params["baseModel"] = base_model params["baseModel"] = base_model
if sort is not None:
params["sort"] = sort
headers = {} headers = {}
if token: if token:
headers["Authorization"] = f"Bearer {token}" headers["Authorization"] = f"Bearer {token}"
...@@ -686,25 +836,30 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] ...@@ -686,25 +836,30 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
# Sort versions within the selected repo by download count # Sort versions within the selected repo by download count
sorted_versions = sorted( sorted_versions = sorted(
selected_repo["modelVersions"], key=lambda x: x["stats"]["downloadCount"], reverse=True selected_repo["modelVersions"],
key=lambda x: x["stats"]["downloadCount"],
reverse=True,
) )
for selected_version in sorted_versions: for selected_version in sorted_versions:
version_id = selected_version["id"] version_id = selected_version["id"]
trainedWords = selected_version["trainedWords"]
models_list = [] models_list = []
for model_data in selected_version["files"]: # When searching for textual inversion, results other than the values entered for the base model may come up, so check again.
# Check if the file passes security scans and has a valid extension if base_model is None or selected_version["baseModel"] in base_model:
file_name = model_data["name"] for model_data in selected_version["files"]:
if ( # Check if the file passes security scans and has a valid extension
model_data["pickleScanResult"] == "Success" file_name = model_data["name"]
and model_data["virusScanResult"] == "Success" if (
and any(file_name.endswith(ext) for ext in EXTENSION) model_data["pickleScanResult"] == "Success"
and os.path.basename(os.path.dirname(file_name)) not in DIFFUSERS_CONFIG_DIR and model_data["virusScanResult"] == "Success"
): and any(file_name.endswith(ext) for ext in EXTENSION)
file_status = { and os.path.basename(os.path.dirname(file_name)) not in DIFFUSERS_CONFIG_DIR
"filename": file_name, ):
"download_url": model_data["downloadUrl"], file_status = {
} "filename": file_name,
models_list.append(file_status) "download_url": model_data["downloadUrl"],
}
models_list.append(file_status)
if models_list: if models_list:
# Sort the models list by filename and find the safest model # Sort the models list by filename and find the safest model
...@@ -764,19 +919,229 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] ...@@ -764,19 +919,229 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
repo_status=RepoStatus(repo_id=repo_name, repo_hash=repo_id, version=version_id), repo_status=RepoStatus(repo_id=repo_name, repo_hash=repo_id, version=version_id),
model_status=ModelStatus( model_status=ModelStatus(
search_word=search_word, search_word=search_word,
site_url=f"https://civitai.com/models/{repo_id}?modelVersionId={version_id}",
download_url=download_url, download_url=download_url,
file_name=file_name, file_name=file_name,
local=output_info["type"]["local"], local=output_info["type"]["local"],
), ),
extra_status=ExtraStatus(trained_words=trainedWords or None),
) )
class EasyPipelineForText2Image(AutoPipelineForText2Image): def add_methods(pipeline):
r""" r"""
Add methods from `AutoConfig` to the pipeline.
Parameters:
pipeline (`Pipeline`):
The pipeline to which the methods will be added.
"""
for attr_name in dir(AutoConfig):
attr_value = getattr(AutoConfig, attr_name)
if callable(attr_value) and not attr_name.startswith("__"):
setattr(pipeline, attr_name, types.MethodType(attr_value, pipeline))
return pipeline
class AutoConfig:
def auto_load_textual_inversion(
self,
pretrained_model_name_or_path: Union[str, List[str]],
token: Optional[Union[str, List[str]]] = None,
base_model: Optional[Union[str, List[str]]] = None,
tokenizer=None,
text_encoder=None,
**kwargs,
):
r"""
Load Textual Inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
Automatic1111 formats are supported).
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
Can be either one of the following or a list of them:
- Search keywords for pretrained model (for example `EasyNegative`).
- A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
pretrained model hosted on the Hub.
- A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
inversion weights.
- A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
token (`str` or `List[str]`, *optional*):
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
list, then `token` must also be a list of equal length.
text_encoder ([`~transformers.CLIPTextModel`], *optional*):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
If not specified, function will take self.tokenizer.
tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
A `CLIPTokenizer` to tokenize text. If not specified, function will take self.tokenizer.
weight_name (`str`, *optional*):
Name of a custom weight file. This should be used when:
- The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
name such as `text_inv.bin`.
- The saved textual inversion file is in the Automatic1111 format.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
Examples:
```py
>>> from auto_diffusers import EasyPipelineForText2Image
>>> pipeline = EasyPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
>>> pipeline.auto_load_textual_inversion("EasyNegative", token="EasyNegative")
>>> image = pipeline(prompt).images[0]
```
"""
# 1. Set tokenizer and text encoder
tokenizer = tokenizer or getattr(self, "tokenizer", None)
text_encoder = text_encoder or getattr(self, "text_encoder", None)
# Check if tokenizer and text encoder are provided
if tokenizer is None or text_encoder is None:
raise ValueError("Tokenizer and text encoder must be provided.")
# 2. Normalize inputs
pretrained_model_name_or_paths = (
[pretrained_model_name_or_path]
if not isinstance(pretrained_model_name_or_path, list)
else pretrained_model_name_or_path
)
# 2.1 Normalize tokens
tokens = [token] if not isinstance(token, list) else token
if tokens[0] is None:
tokens = tokens * len(pretrained_model_name_or_paths)
for check_token in tokens:
# Check if token is already in tokenizer vocabulary
if check_token in tokenizer.get_vocab():
raise ValueError(
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
)
expected_shape = text_encoder.get_input_embeddings().weight.shape[-1] # Expected shape of tokenizer
for search_word in pretrained_model_name_or_paths:
if isinstance(search_word, str):
# Update kwargs to ensure the model is downloaded and parameters are included
_status = {
"download": True,
"include_params": True,
"skip_error": False,
"model_type": "TextualInversion",
}
# Get tags for the base model of textual inversion compatible with tokenizer.
# If the tokenizer is 768-dimensional, set tags for SD 1.x and SDXL.
# If the tokenizer is 1024-dimensional, set tags for SD 2.x.
if expected_shape in TOKENIZER_SHAPE_MAP:
# Retrieve the appropriate tags from the TOKENIZER_SHAPE_MAP based on the expected shape
tags = TOKENIZER_SHAPE_MAP[expected_shape]
if base_model is not None:
if isinstance(base_model, list):
tags.extend(base_model)
else:
tags.append(base_model)
_status["base_model"] = tags
kwargs.update(_status)
# Search for the model on Civitai and get the model status
textual_inversion_path = search_civitai(search_word, **kwargs)
logger.warning(
f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
)
pretrained_model_name_or_paths[
pretrained_model_name_or_paths.index(search_word)
] = textual_inversion_path.model_path
self.load_textual_inversion(
pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs
)
def auto_load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
r"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
`self.text_encoder`.
[`AutoPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. The All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
loaded into `self.unet`.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
dict is loaded into `self.text_encoder`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if isinstance(pretrained_model_name_or_path_or_dict, str):
# Update kwargs to ensure the model is downloaded and parameters are included
_status = {
"download": True,
"include_params": True,
"skip_error": False,
"model_type": "LORA",
}
kwargs.update(_status)
# Search for the model on Civitai and get the model status
lora_path = search_civitai(pretrained_model_name_or_path_or_dict, **kwargs)
logger.warning(f"lora_path: {lora_path.model_status.site_url}")
logger.warning(f"trained_words: {lora_path.extra_status.trained_words}")
pretrained_model_name_or_path_or_dict = lora_path.model_path
self.load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name=adapter_name, **kwargs)
class EasyPipelineForText2Image(AutoPipelineForText2Image):
r"""
[`EasyPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. The
specific underlying pipeline class is automatically selected from either the specific underlying pipeline class is automatically selected from either the
[`~AutoPipelineForText2Image.from_pretrained`] or [`~AutoPipelineForText2Image.from_pipe`] methods. [`~EasyPipelineForText2Image.from_pretrained`], [`~EasyPipelineForText2Image.from_pipe`], [`~EasyPipelineForText2Image.from_huggingface`] or [`~EasyPipelineForText2Image.from_civitai`] methods.
This class cannot be instantiated using `__init__()` (throws an error). This class cannot be instantiated using `__init__()` (throws an error).
...@@ -891,9 +1256,9 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image): ...@@ -891,9 +1256,9 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
Examples: Examples:
```py ```py
>>> from diffusers import AutoPipelineForText2Image >>> from auto_diffusers import EasyPipelineForText2Image
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5") >>> pipeline = EasyPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
>>> image = pipeline(prompt).images[0] >>> image = pipeline(prompt).images[0]
``` ```
""" """
...@@ -907,20 +1272,21 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image): ...@@ -907,20 +1272,21 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
kwargs.update(_status) kwargs.update(_status)
# Search for the model on Hugging Face and get the model status # Search for the model on Hugging Face and get the model status
hf_model_status = search_huggingface(pretrained_model_link_or_path, **kwargs) hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {hf_model_status.model_status.download_url}") logger.warning(f"checkpoint_path: {hf_checkpoint_status.model_status.download_url}")
checkpoint_path = hf_model_status.model_path checkpoint_path = hf_checkpoint_status.model_path
# Check the format of the model checkpoint # Check the format of the model checkpoint
if hf_model_status.checkpoint_format == "single_file": if hf_checkpoint_status.loading_method == "from_single_file":
# Load the pipeline from a single file checkpoint # Load the pipeline from a single file checkpoint
return load_pipeline_from_single_file( pipeline = load_pipeline_from_single_file(
pretrained_model_or_path=checkpoint_path, pretrained_model_or_path=checkpoint_path,
pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING, pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING,
**kwargs, **kwargs,
) )
else: else:
return cls.from_pretrained(checkpoint_path, **kwargs) pipeline = cls.from_pretrained(checkpoint_path, **kwargs)
return add_methods(pipeline)
@classmethod @classmethod
def from_civitai(cls, pretrained_model_link_or_path, **kwargs): def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
...@@ -999,9 +1365,9 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image): ...@@ -999,9 +1365,9 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
Examples: Examples:
```py ```py
>>> from diffusers import AutoPipelineForText2Image >>> from auto_diffusers import EasyPipelineForText2Image
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5") >>> pipeline = EasyPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
>>> image = pipeline(prompt).images[0] >>> image = pipeline(prompt).images[0]
``` ```
""" """
...@@ -1015,24 +1381,25 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image): ...@@ -1015,24 +1381,25 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
kwargs.update(_status) kwargs.update(_status)
# Search for the model on Civitai and get the model status # Search for the model on Civitai and get the model status
model_status = search_civitai(pretrained_model_link_or_path, **kwargs) checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}") logger.warning(f"checkpoint_path: {checkpoint_status.model_status.site_url}")
checkpoint_path = model_status.model_path checkpoint_path = checkpoint_status.model_path
# Load the pipeline from a single file checkpoint # Load the pipeline from a single file checkpoint
return load_pipeline_from_single_file( pipeline = load_pipeline_from_single_file(
pretrained_model_or_path=checkpoint_path, pretrained_model_or_path=checkpoint_path,
pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING, pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING,
**kwargs, **kwargs,
) )
return add_methods(pipeline)
class EasyPipelineForImage2Image(AutoPipelineForImage2Image): class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
r""" r"""
[`AutoPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The [`EasyPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The
specific underlying pipeline class is automatically selected from either the specific underlying pipeline class is automatically selected from either the
[`~AutoPipelineForImage2Image.from_pretrained`] or [`~AutoPipelineForImage2Image.from_pipe`] methods. [`~EasyPipelineForImage2Image.from_pretrained`], [`~EasyPipelineForImage2Image.from_pipe`], [`~EasyPipelineForImage2Image.from_huggingface`] or [`~EasyPipelineForImage2Image.from_civitai`] methods.
This class cannot be instantiated using `__init__()` (throws an error). This class cannot be instantiated using `__init__()` (throws an error).
...@@ -1147,10 +1514,10 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image): ...@@ -1147,10 +1514,10 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
Examples: Examples:
```py ```py
>>> from diffusers import AutoPipelineForText2Image >>> from auto_diffusers import EasyPipelineForImage2Image
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5") >>> pipeline = EasyPipelineForImage2Image.from_huggingface("stable-diffusion-v1-5")
>>> image = pipeline(prompt).images[0] >>> image = pipeline(prompt, image).images[0]
``` ```
""" """
# Update kwargs to ensure the model is downloaded and parameters are included # Update kwargs to ensure the model is downloaded and parameters are included
...@@ -1163,20 +1530,22 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image): ...@@ -1163,20 +1530,22 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
kwargs.update(_parmas) kwargs.update(_parmas)
# Search for the model on Hugging Face and get the model status # Search for the model on Hugging Face and get the model status
model_status = search_huggingface(pretrained_model_link_or_path, **kwargs) hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}") logger.warning(f"checkpoint_path: {hf_checkpoint_status.model_status.download_url}")
checkpoint_path = model_status.model_path checkpoint_path = hf_checkpoint_status.model_path
# Check the format of the model checkpoint # Check the format of the model checkpoint
if model_status.checkpoint_format == "single_file": if hf_checkpoint_status.loading_method == "from_single_file":
# Load the pipeline from a single file checkpoint # Load the pipeline from a single file checkpoint
return load_pipeline_from_single_file( pipeline = load_pipeline_from_single_file(
pretrained_model_or_path=checkpoint_path, pretrained_model_or_path=checkpoint_path,
pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING, pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,
**kwargs, **kwargs,
) )
else: else:
return cls.from_pretrained(checkpoint_path, **kwargs) pipeline = cls.from_pretrained(checkpoint_path, **kwargs)
return add_methods(pipeline)
@classmethod @classmethod
def from_civitai(cls, pretrained_model_link_or_path, **kwargs): def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
...@@ -1255,10 +1624,10 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image): ...@@ -1255,10 +1624,10 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
Examples: Examples:
```py ```py
>>> from diffusers import AutoPipelineForText2Image >>> from auto_diffusers import EasyPipelineForImage2Image
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5") >>> pipeline = EasyPipelineForImage2Image.from_huggingface("stable-diffusion-v1-5")
>>> image = pipeline(prompt).images[0] >>> image = pipeline(prompt, image).images[0]
``` ```
""" """
# Update kwargs to ensure the model is downloaded and parameters are included # Update kwargs to ensure the model is downloaded and parameters are included
...@@ -1271,24 +1640,25 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image): ...@@ -1271,24 +1640,25 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
kwargs.update(_status) kwargs.update(_status)
# Search for the model on Civitai and get the model status # Search for the model on Civitai and get the model status
model_status = search_civitai(pretrained_model_link_or_path, **kwargs) checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}") logger.warning(f"checkpoint_path: {checkpoint_status.model_status.site_url}")
checkpoint_path = model_status.model_path checkpoint_path = checkpoint_status.model_path
# Load the pipeline from a single file checkpoint # Load the pipeline from a single file checkpoint
return load_pipeline_from_single_file( pipeline = load_pipeline_from_single_file(
pretrained_model_or_path=checkpoint_path, pretrained_model_or_path=checkpoint_path,
pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING, pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,
**kwargs, **kwargs,
) )
return add_methods(pipeline)
class EasyPipelineForInpainting(AutoPipelineForInpainting): class EasyPipelineForInpainting(AutoPipelineForInpainting):
r""" r"""
[`AutoPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. The [`EasyPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. The
specific underlying pipeline class is automatically selected from either the specific underlying pipeline class is automatically selected from either the
[`~AutoPipelineForInpainting.from_pretrained`] or [`~AutoPipelineForInpainting.from_pipe`] methods. [`~EasyPipelineForInpainting.from_pretrained`], [`~EasyPipelineForInpainting.from_pipe`], [`~EasyPipelineForInpainting.from_huggingface`] or [`~EasyPipelineForInpainting.from_civitai`] methods.
This class cannot be instantiated using `__init__()` (throws an error). This class cannot be instantiated using `__init__()` (throws an error).
...@@ -1403,10 +1773,10 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting): ...@@ -1403,10 +1773,10 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
Examples: Examples:
```py ```py
>>> from diffusers import AutoPipelineForText2Image >>> from auto_diffusers import EasyPipelineForInpainting
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5") >>> pipeline = EasyPipelineForInpainting.from_huggingface("stable-diffusion-2-inpainting")
>>> image = pipeline(prompt).images[0] >>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
``` ```
""" """
# Update kwargs to ensure the model is downloaded and parameters are included # Update kwargs to ensure the model is downloaded and parameters are included
...@@ -1419,20 +1789,21 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting): ...@@ -1419,20 +1789,21 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
kwargs.update(_status) kwargs.update(_status)
# Search for the model on Hugging Face and get the model status # Search for the model on Hugging Face and get the model status
model_status = search_huggingface(pretrained_model_link_or_path, **kwargs) hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}") logger.warning(f"checkpoint_path: {hf_checkpoint_status.model_status.download_url}")
checkpoint_path = model_status.model_path checkpoint_path = hf_checkpoint_status.model_path
# Check the format of the model checkpoint # Check the format of the model checkpoint
if model_status.checkpoint_format == "single_file": if hf_checkpoint_status.loading_method == "from_single_file":
# Load the pipeline from a single file checkpoint # Load the pipeline from a single file checkpoint
return load_pipeline_from_single_file( pipeline = load_pipeline_from_single_file(
pretrained_model_or_path=checkpoint_path, pretrained_model_or_path=checkpoint_path,
pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING, pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,
**kwargs, **kwargs,
) )
else: else:
return cls.from_pretrained(checkpoint_path, **kwargs) pipeline = cls.from_pretrained(checkpoint_path, **kwargs)
return add_methods(pipeline)
@classmethod @classmethod
def from_civitai(cls, pretrained_model_link_or_path, **kwargs): def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
...@@ -1511,10 +1882,10 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting): ...@@ -1511,10 +1882,10 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
Examples: Examples:
```py ```py
>>> from diffusers import AutoPipelineForText2Image >>> from auto_diffusers import EasyPipelineForInpainting
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5") >>> pipeline = EasyPipelineForInpainting.from_huggingface("stable-diffusion-2-inpainting")
>>> image = pipeline(prompt).images[0] >>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
``` ```
""" """
# Update kwargs to ensure the model is downloaded and parameters are included # Update kwargs to ensure the model is downloaded and parameters are included
...@@ -1527,13 +1898,14 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting): ...@@ -1527,13 +1898,14 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
kwargs.update(_status) kwargs.update(_status)
# Search for the model on Civitai and get the model status # Search for the model on Civitai and get the model status
model_status = search_civitai(pretrained_model_link_or_path, **kwargs) checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}") logger.warning(f"checkpoint_path: {checkpoint_status.model_status.site_url}")
checkpoint_path = model_status.model_path checkpoint_path = checkpoint_status.model_path
# Load the pipeline from a single file checkpoint # Load the pipeline from a single file checkpoint
return load_pipeline_from_single_file( pipeline = load_pipeline_from_single_file(
pretrained_model_or_path=checkpoint_path, pretrained_model_or_path=checkpoint_path,
pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING, pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,
**kwargs, **kwargs,
) )
return add_methods(pipeline)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment