"vscode:/vscode.git/clone" did not exist on "9c5bf342bc34f94de9aa4a171d726e6b341a91e6"
Unverified Commit ae672d58 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Tests] Lower required memory for clip guided and fix super edge-case git...

[Tests] Lower required memory for clip guided and fix super edge-case git pipeline module bug (#754)

* [Tests] Lower required memory

* fix

* up

* uP
parent 2fa55fc7
...@@ -259,7 +259,8 @@ def get_cached_module_file( ...@@ -259,7 +259,8 @@ def get_cached_module_file(
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=False, use_auth_token=False,
) )
submodule = "local" submodule = "git"
module_file = pretrained_model_name_or_path + ".py"
except EnvironmentError: except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise raise
...@@ -288,7 +289,7 @@ def get_cached_module_file( ...@@ -288,7 +289,7 @@ def get_cached_module_file(
full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
create_dynamic_module(full_submodule) create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule submodule_path = Path(HF_MODULES_CACHE) / full_submodule
if submodule == "local": if submodule == "local" or submodule == "git":
# We always copy local files (we could hash the file to see if there was a change, and give them the name of # We always copy local files (we could hash the file to see if there was a change, and give them the name of
# that hash, to only copy when there is a modification but it seems overkill for now). # that hash, to only copy when there is a modification but it seems overkill for now).
# The only reason we do the copy is to avoid putting too many folders in sys.path. # The only reason we do the copy is to avoid putting too many folders in sys.path.
......
...@@ -112,18 +112,22 @@ class CustomPipelineTests(unittest.TestCase): ...@@ -112,18 +112,22 @@ class CustomPipelineTests(unittest.TestCase):
assert output_str == "This is a local test" assert output_str == "This is a local test"
@slow @slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_load_pipeline_from_git(self): def test_load_pipeline_from_git(self):
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id) feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
clip_model = CLIPModel.from_pretrained(clip_model_id) clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-4",
custom_pipeline="clip_guided_stable_diffusion", custom_pipeline="clip_guided_stable_diffusion",
clip_model=clip_model, clip_model=clip_model,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
torch_dtype=torch.float16,
revision="fp16",
) )
pipeline.enable_attention_slicing()
pipeline = pipeline.to(torch_device) pipeline = pipeline.to(torch_device)
# NOTE that `"CLIPGuidedStableDiffusion"` is not a class that is defined in the pypi package of th e library, but solely on the community examples folder of GitHub under: # NOTE that `"CLIPGuidedStableDiffusion"` is not a class that is defined in the pypi package of th e library, but solely on the community examples folder of GitHub under:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment