Unverified Commit 4191ddee authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

Revert revert and install accelerate main (#4963)

* Revert "Temp Revert "[Core] better support offloading when side loading is enabled… (#4927)"

This reverts commit 2ab17049.

* tests: install accelerate from main
parent 2ab17049
...@@ -67,6 +67,7 @@ jobs: ...@@ -67,6 +67,7 @@ jobs:
run: | run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test] python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate.git
- name: Environment - name: Environment
run: | run: |
......
...@@ -63,6 +63,7 @@ jobs: ...@@ -63,6 +63,7 @@ jobs:
run: | run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test] python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate.git
- name: Environment - name: Environment
run: | run: |
......
...@@ -40,7 +40,7 @@ jobs: ...@@ -40,7 +40,7 @@ jobs:
${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install --upgrade pip
${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install -e .[quality,test]
${CONDA_RUN} python -m pip install torch torchvision torchaudio ${CONDA_RUN} python -m pip install torch torchvision torchaudio
${CONDA_RUN} python -m pip install accelerate --upgrade ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate.git
${CONDA_RUN} python -m pip install transformers --upgrade ${CONDA_RUN} python -m pip install transformers --upgrade
- name: Environment - name: Environment
......
...@@ -45,6 +45,7 @@ if is_transformers_available(): ...@@ -45,6 +45,7 @@ if is_transformers_available():
if is_accelerate_available(): if is_accelerate_available():
from accelerate import init_empty_weights from accelerate import init_empty_weights
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
from accelerate.utils import set_module_tensor_to_device from accelerate.utils import set_module_tensor_to_device
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -778,6 +779,21 @@ class TextualInversionLoaderMixin: ...@@ -778,6 +779,21 @@ class TextualInversionLoaderMixin:
f" `{self.load_textual_inversion.__name__}`" f" `{self.load_textual_inversion.__name__}`"
) )
# Remove any existing hooks.
is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
...@@ -931,6 +947,12 @@ class TextualInversionLoaderMixin: ...@@ -931,6 +947,12 @@ class TextualInversionLoaderMixin:
for token_id, embedding in token_ids_and_embeddings: for token_id, embedding in token_ids_and_embeddings:
text_encoder.get_input_embeddings().weight.data[token_id] = embedding text_encoder.get_input_embeddings().weight.data[token_id] = embedding
# offload back
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
class LoraLoaderMixin: class LoraLoaderMixin:
r""" r"""
...@@ -962,6 +984,21 @@ class LoraLoaderMixin: ...@@ -962,6 +984,21 @@ class LoraLoaderMixin:
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`]. See [`~loaders.LoraLoaderMixin.lora_state_dict`].
""" """
# Remove any existing hooks.
is_model_cpu_offload = False
is_sequential_cpu_offload = False
recurive = False
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recurive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recurive)
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
self.load_lora_into_text_encoder( self.load_lora_into_text_encoder(
...@@ -971,6 +1008,12 @@ class LoraLoaderMixin: ...@@ -971,6 +1008,12 @@ class LoraLoaderMixin:
lora_scale=self.lora_scale, lora_scale=self.lora_scale,
) )
# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
@classmethod @classmethod
def lora_state_dict( def lora_state_dict(
cls, cls,
......
...@@ -1549,6 +1549,26 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi ...@@ -1549,6 +1549,26 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi
# We could have accessed the unet config from `lora_state_dict()` too. We pass # We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL # it here explicitly to be able to tell that it's coming from an SDXL
# pipeline. # pipeline.
# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
state_dict, network_alphas = self.lora_state_dict( state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config, unet_config=self.unet.config,
...@@ -1576,6 +1596,12 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi ...@@ -1576,6 +1596,12 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi
lora_scale=self.lora_scale, lora_scale=self.lora_scale,
) )
# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
@classmethod @classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def save_lora_weights( def save_lora_weights(
......
...@@ -1216,6 +1216,26 @@ class StableDiffusionXLControlNetPipeline( ...@@ -1216,6 +1216,26 @@ class StableDiffusionXLControlNetPipeline(
# We could have accessed the unet config from `lora_state_dict()` too. We pass # We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL # it here explicitly to be able to tell that it's coming from an SDXL
# pipeline. # pipeline.
# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
state_dict, network_alphas = self.lora_state_dict( state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config, unet_config=self.unet.config,
...@@ -1243,6 +1263,12 @@ class StableDiffusionXLControlNetPipeline( ...@@ -1243,6 +1263,12 @@ class StableDiffusionXLControlNetPipeline(
lora_scale=self.lora_scale, lora_scale=self.lora_scale,
) )
# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
@classmethod @classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def save_lora_weights( def save_lora_weights(
......
...@@ -922,6 +922,26 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -922,6 +922,26 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
# We could have accessed the unet config from `lora_state_dict()` too. We pass # We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL # it here explicitly to be able to tell that it's coming from an SDXL
# pipeline. # pipeline.
# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
state_dict, network_alphas = self.lora_state_dict( state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config, unet_config=self.unet.config,
...@@ -949,6 +969,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -949,6 +969,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
lora_scale=self.lora_scale, lora_scale=self.lora_scale,
) )
# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
@classmethod @classmethod
def save_lora_weights( def save_lora_weights(
self, self,
......
...@@ -1072,6 +1072,26 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -1072,6 +1072,26 @@ class StableDiffusionXLImg2ImgPipeline(
# We could have accessed the unet config from `lora_state_dict()` too. We pass # We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL # it here explicitly to be able to tell that it's coming from an SDXL
# pipeline. # pipeline.
# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
state_dict, network_alphas = self.lora_state_dict( state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config, unet_config=self.unet.config,
...@@ -1099,6 +1119,12 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -1099,6 +1119,12 @@ class StableDiffusionXLImg2ImgPipeline(
lora_scale=self.lora_scale, lora_scale=self.lora_scale,
) )
# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
@classmethod @classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def save_lora_weights( def save_lora_weights(
......
...@@ -1392,6 +1392,26 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1392,6 +1392,26 @@ class StableDiffusionXLInpaintPipeline(
# We could have accessed the unet config from `lora_state_dict()` too. We pass # We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL # it here explicitly to be able to tell that it's coming from an SDXL
# pipeline. # pipeline.
# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
state_dict, network_alphas = self.lora_state_dict( state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config, unet_config=self.unet.config,
...@@ -1419,6 +1439,12 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1419,6 +1439,12 @@ class StableDiffusionXLInpaintPipeline(
lora_scale=self.lora_scale, lora_scale=self.lora_scale,
) )
# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
@classmethod @classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def save_lora_weights( def save_lora_weights(
......
...@@ -1081,6 +1081,42 @@ class LoraIntegrationTests(unittest.TestCase): ...@@ -1081,6 +1081,42 @@ class LoraIntegrationTests(unittest.TestCase):
self.assertTrue(np.allclose(images, expected, atol=1e-3)) self.assertTrue(np.allclose(images, expected, atol=1e-3))
def test_a1111_with_model_cpu_offload(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
pipe.enable_model_cpu_offload()
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
def test_a1111_with_sequential_cpu_offload(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
pipe.enable_sequential_cpu_offload()
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
def test_kohya_sd_v15_with_higher_dimensions(self): def test_kohya_sd_v15_with_higher_dimensions(self):
generator = torch.Generator().manual_seed(0) generator = torch.Generator().manual_seed(0)
...@@ -1257,10 +1293,10 @@ class LoraIntegrationTests(unittest.TestCase): ...@@ -1257,10 +1293,10 @@ class LoraIntegrationTests(unittest.TestCase):
generator = torch.Generator().manual_seed(0) generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.enable_model_cpu_offload()
lora_model_id = "hf-internal-testing/sdxl-1.0-lora" lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.enable_model_cpu_offload()
images = pipe( images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
...@@ -1411,3 +1447,21 @@ class LoraIntegrationTests(unittest.TestCase): ...@@ -1411,3 +1447,21 @@ class LoraIntegrationTests(unittest.TestCase):
assert state_dicts_almost_equal(text_encoder_1_sd, pipe.text_encoder.state_dict()) assert state_dicts_almost_equal(text_encoder_1_sd, pipe.text_encoder.state_dict())
assert state_dicts_almost_equal(text_encoder_2_sd, pipe.text_encoder_2.state_dict()) assert state_dicts_almost_equal(text_encoder_2_sd, pipe.text_encoder_2.state_dict())
assert state_dicts_almost_equal(unet_sd, pipe.unet.state_dict()) assert state_dicts_almost_equal(unet_sd, pipe.unet.state_dict())
def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.enable_sequential_cpu_offload()
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
...@@ -1019,6 +1019,56 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase): ...@@ -1019,6 +1019,56 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
max_diff = np.abs(expected_image - image).max() max_diff = np.abs(expected_image - image).max()
assert max_diff < 8e-1 assert max_diff < 8e-1
def test_stable_diffusion_textual_inversion_with_model_cpu_offload(self):
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.enable_model_cpu_offload()
pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")
a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt")
a111_file_neg = hf_hub_download(
"hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt"
)
pipe.load_textual_inversion(a111_file)
pipe.load_textual_inversion(a111_file_neg)
generator = torch.Generator(device="cpu").manual_seed(1)
prompt = "An logo of a turtle in strong Style-Winter with <low-poly-hd-logos-icons>"
neg_prompt = "Style-Winter-neg"
image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 8e-1
def test_stable_diffusion_textual_inversion_with_sequential_cpu_offload(self):
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.enable_sequential_cpu_offload()
pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")
a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt")
a111_file_neg = hf_hub_download(
"hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt"
)
pipe.load_textual_inversion(a111_file)
pipe.load_textual_inversion(a111_file_neg)
generator = torch.Generator(device="cpu").manual_seed(1)
prompt = "An logo of a turtle in strong Style-Winter with <low-poly-hd-logos-icons>"
neg_prompt = "Style-Winter-neg"
image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 8e-1
@require_torch_2 @require_torch_2
def test_stable_diffusion_compile(self): def test_stable_diffusion_compile(self):
seed = 0 seed = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment