Unverified commit fab17528, authored by Patrick von Platen, committed by GitHub

[Low CPU memory] + device map (#772)



* add accelerate to load models with smaller memory footprint

* remove low_cpu_mem_usage as it is redundant

* move accelerate init weights context to modeling utils

* add test to ensure results are the same when loading with accelerate

* add tests to ensure ram usage gets lower when using accelerate

* move accelerate logic to a single snippet under modeling utils and remove it from configuration utils

* format code to pass the quality check

* fix imports with isort

* add accelerate to test extra deps

* only import accelerate if device_map is set to auto

* move accelerate availability check to diffusers import utils

* format code

* add device map to pipeline abstraction

* lint it to pass PR quality check

* fix class check to use accelerate when using diffusers ModelMixin subclasses

* use low_cpu_mem_usage in transformers if device_map is not available

* NoModuleLayer

* comment out tests

* up

* uP

* finish

* Update src/diffusers/pipelines/stable_diffusion/safety_checker.py

* finish

* uP

* make style
Co-authored-by: Pi Esposito <piero.skywalker@gmail.com>
parent feaa7324
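
The feature this commit adds is driven entirely through `from_pretrained`. A minimal usage sketch, mirroring the call in the new tests below (model id, revision, and dtype are taken from those tests; `device_map="auto"` is the argument this PR introduces):

    import torch
    from diffusers import StableDiffusionPipeline

    # device_map="auto" asks accelerate to place weights as they are loaded,
    # instead of first materializing the whole pipeline in CPU RAM.
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=True,
        device_map="auto",
    )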
@@ -32,7 +32,19 @@ from tqdm.auto import tqdm
 from .configuration_utils import ConfigMixin
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
+from .utils import (
+    CONFIG_NAME,
+    DIFFUSERS_CACHE,
+    ONNX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    BaseOutput,
+    is_transformers_available,
+    logging,
+)
+
+
+if is_transformers_available():
+    from transformers import PreTrainedModel
+
 
 INDEX_FILE = "diffusion_pytorch_model.bin"
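
The commit messages above mention moving accelerate's init-weights context into the modeling utils. That part of the change is not shown in this diff, but the core pattern, sketched here as an assumption about what the modeling-utils code does (not a copy of it), is to build the module skeleton on the meta device so that no storage is allocated until the checkpoint weights are loaded:

    import torch.nn as nn
    from accelerate import init_empty_weights

    # Inside this context, newly created parameters live on the "meta"
    # device: they carry shape and dtype but no actual memory.
    with init_empty_weights():
        skeleton = nn.Linear(4096, 4096)

    assert skeleton.weight.device.type == "meta"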
@@ -338,6 +350,7 @@ class DiffusionPipeline(ConfigMixin):
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         provider = kwargs.pop("provider", None)
         sess_options = kwargs.pop("sess_options", None)
+        device_map = kwargs.pop("device_map", None)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -463,6 +476,13 @@ class DiffusionPipeline(ConfigMixin):
                 loading_kwargs["provider"] = provider
                 loading_kwargs["sess_options"] = sess_options
 
+            if (
+                issubclass(class_obj, diffusers.ModelMixin)
+                or is_transformers_available()
+                and issubclass(class_obj, PreTrainedModel)
+            ):
+                loading_kwargs["device_map"] = device_map
+
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
                 loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
...
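
A detail worth calling out in the condition above: Python parses `A or B and C` as `A or (B and C)`, so `issubclass(class_obj, PreTrainedModel)` is only evaluated when `is_transformers_available()` returned True, and the guarded import earlier in the file stays safe. A toy demonstration with hypothetical stand-in classes:

    # `and` binds tighter than `or`: the right-hand issubclass never runs
    # when the availability flag is False.
    class ModelMixin: ...
    class PreTrainedModel: ...

    def wants_device_map(class_obj, transformers_available):
        return issubclass(class_obj, ModelMixin) or (
            transformers_available and issubclass(class_obj, PreTrainedModel)
        )

    print(wants_device_map(ModelMixin, False))      # True
    print(wants_device_map(PreTrainedModel, True))  # True
    print(wants_device_map(str, False))             # False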
@@ -19,6 +19,8 @@ def cosine_distance(image_embeds, text_embeds):
 
 class StableDiffusionSafetyChecker(PreTrainedModel):
     config_class = CLIPConfig
 
+    _no_split_modules = ["CLIPEncoderLayer"]
+
     def __init__(self, config: CLIPConfig):
         super().__init__(config)
 
@@ -28,8 +30,8 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
         self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
         self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
 
-        self.register_buffer("concept_embeds_weights", torch.ones(17))
-        self.register_buffer("special_care_embeds_weights", torch.ones(3))
+        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
 
     @torch.no_grad()
     def forward(self, clip_input, images):
...
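
Two notes on this file. The `register_buffer` to frozen `nn.Parameter` swap keeps the same state-dict keys while turning the tensors into ordinary non-trainable parameters, presumably so accelerate's weight-placement path treats them like every other weight. And `_no_split_modules` names the submodule classes that must not be sharded across devices; a sketch of the accelerate mechanism it is assumed to feed into (`infer_auto_device_map` accepts a `no_split_module_classes` argument):

    from accelerate import infer_auto_device_map, init_empty_weights
    from transformers import CLIPConfig, CLIPModel

    # Build a weightless skeleton, then infer a device map that never
    # splits a CLIPEncoderLayer between two devices.
    with init_empty_weights():
        model = CLIPModel(CLIPConfig())

    device_map = infer_auto_device_map(model, no_split_module_classes=["CLIPEncoderLayer"])
    print(device_map)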
@@ -17,12 +17,15 @@ import gc
 import os
 import random
 import tempfile
+import tracemalloc
 import unittest
 
 import numpy as np
 import torch
 
+import accelerate
 import PIL
+import transformers
 from diffusers import (
     AutoencoderKL,
     DDIMPipeline,
@@ -50,6 +53,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
 from diffusers.utils.testing_utils import get_tests_dir
+from packaging import version
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -2034,3 +2038,53 @@ class PipelineTesterMixin(unittest.TestCase):
         pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
         assert test_callback_fn.has_been_called
         assert number_of_steps == 6
+
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_accelerate_load_works(self):
+        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
+            return
+
+        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
+            return
+
+        model_id = "CompVis/stable-diffusion-v1-4"
+        _ = StableDiffusionPipeline.from_pretrained(
+            model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
+        ).to(torch_device)
+
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
+        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
+            return
+
+        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
+            return
+
+        pipeline_id = "CompVis/stable-diffusion-v1-4"
+
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        tracemalloc.start()
+        pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+        )
+        pipeline_normal_load.to(torch_device)
+        _, peak_normal = tracemalloc.get_traced_memory()
+        tracemalloc.stop()
+
+        del pipeline_normal_load
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        tracemalloc.start()
+        _ = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
+        )
+        _, peak_accelerate = tracemalloc.get_traced_memory()
+        tracemalloc.stop()
+
+        assert peak_accelerate < peak_normal
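
On the measurement in the second test: `tracemalloc` traces allocations made through Python's memory allocator on the host, so its peak is used here as a proxy for CPU-RAM pressure during loading; it says nothing about GPU memory. A standalone sketch of the pattern:

    import tracemalloc

    tracemalloc.start()
    buf = bytearray(50 * 1024 * 1024)  # ~50 MiB allocated through Python
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    print(f"peak traced allocation: {peak / 1024**2:.1f} MiB")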