Unverified Commit 0160e514 authored by w4ffl35's avatar w4ffl35 Committed by GitHub
Browse files

Adds local_files_only bool to prevent forced online connection (#3486)

parent 194b0a42
...@@ -727,8 +727,8 @@ def convert_ldm_bert_checkpoint(checkpoint, config): ...@@ -727,8 +727,8 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
return hf_model return hf_model
def convert_ldm_clip_checkpoint(checkpoint): def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
...@@ -992,6 +992,7 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -992,6 +992,7 @@ def download_from_original_stable_diffusion_ckpt(
controlnet: Optional[bool] = None, controlnet: Optional[bool] = None,
load_safety_checker: bool = True, load_safety_checker: bool = True,
pipeline_class: DiffusionPipeline = None, pipeline_class: DiffusionPipeline = None,
local_files_only=False
) -> DiffusionPipeline: ) -> DiffusionPipeline:
""" """
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
...@@ -1037,6 +1038,8 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1037,6 +1038,8 @@ def download_from_original_stable_diffusion_ckpt(
Whether to load the safety checker or not. Defaults to `True`. Whether to load the safety checker or not. Defaults to `True`.
pipeline_class (`str`, *optional*, defaults to `None`): pipeline_class (`str`, *optional*, defaults to `None`):
The pipeline class to use. Pass `None` to determine automatically. The pipeline class to use. Pass `None` to determine automatically.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
""" """
...@@ -1292,7 +1295,7 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1292,7 +1295,7 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
elif model_type == "FrozenCLIPEmbedder": elif model_type == "FrozenCLIPEmbedder":
text_model = convert_ldm_clip_checkpoint(checkpoint) text_model = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
if load_safety_checker: if load_safety_checker:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment