Unverified commit 3be9fa97, authored by Patrick von Platen and committed by GitHub

[Accelerate model loading] Fix meta device and super low memory usage (#1016)

* [Accelerate model loading] Fix meta device and super low memory usage

* better naming
parent e92a603c
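For context, `device_map="auto"` is the loading path this commit fixes. A minimal usage sketch (model ID, revision, and dtype are taken from the tests in this diff; this assumes `accelerate` is installed and you have access to the checkpoint):

```python
import torch
from diffusers import StableDiffusionPipeline

# Loading with device_map="auto" goes through accelerate's meta device:
# the model skeleton is built without allocating weight memory, and the
# checkpoint tensors are then loaded directly into place.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    device_map="auto",
)
```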
@@ -119,14 +119,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
-    def cuda_with_minimal_gpu_usage(self):
+    def enable_sequential_cpu_offload(self):
         if is_accelerate_available():
             from accelerate import cpu_offload
         else:
             raise ImportError("Please install accelerate via `pip install accelerate`")
 
         device = torch.device("cuda")
-        self.enable_attention_slicing(1)
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
             cpu_offload(cpu_offloaded_model, device)
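The renamed `enable_sequential_cpu_offload` delegates entirely to `accelerate.cpu_offload`. A small stand-alone sketch of what that call does (the toy module is illustrative; the hook behavior described in the comments is summarized from accelerate's docs and may differ slightly across versions):

```python
import torch
from torch import nn
from accelerate import cpu_offload

toy = nn.Linear(8, 8)
cpu_offload(toy, torch.device("cuda"))  # same call pattern as in the diff above

# Between calls the weights live on the CPU; forward hooks copy them to the
# GPU only for the duration of the forward pass, then free the GPU copies.
out = toy(torch.randn(1, 8, device="cuda"))
```

Offloading each submodule (`unet`, `text_encoder`, `vae`, `safety_checker`) this way means at most one component's weights occupy GPU memory at a time, which is what the 1.5 GB assertion in the new test below relies on.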
@@ -15,6 +15,7 @@
 import gc
 import random
 import time
+import unittest
 
 import numpy as np
 import torch
@@ -730,3 +731,39 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         )
         assert test_callback_fn.has_been_called
         assert number_of_steps == 51
+
+    def test_stable_diffusion_accelerate_auto_device(self):
+        pipeline_id = "CompVis/stable-diffusion-v1-4"
+
+        start_time = time.time()
+        pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+        )
+        pipeline_normal_load.to(torch_device)
+        normal_load_time = time.time() - start_time
+
+        start_time = time.time()
+        _ = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
+        )
+        meta_device_load_time = time.time() - start_time
+
+        assert 2 * meta_device_load_time < normal_load_time
+
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+
+        pipeline_id = "CompVis/stable-diffusion-v1-4"
+        prompt = "Andromeda galaxy in a bottle"
+
+        pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
+        pipeline.enable_attention_slicing(1)
+        pipeline.enable_sequential_cpu_offload()
+
+        _ = pipeline(prompt, num_inference_steps=5)
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        # make sure that less than 1.5 GB is allocated
+        assert mem_bytes < 1.5 * 10**9
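The memory assertion above uses PyTorch's CUDA allocator statistics. A self-contained sketch of that measurement pattern (the helper name is made up for illustration):

```python
import torch

def peak_gpu_memory_bytes(fn):
    """Run fn() and return the peak GPU memory the allocator handed out."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()  # newer alias of reset_max_memory_allocated()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

# e.g.: assert peak_gpu_memory_bytes(lambda: pipeline(prompt, num_inference_steps=5)) < 1.5 * 10**9
```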
@@ -17,15 +17,12 @@ import gc
 import os
 import random
 import tempfile
-import tracemalloc
 import unittest
 
 import numpy as np
 import torch
 
-import accelerate
 import PIL
-import transformers
 from diffusers import (
     AutoencoderKL,
     DDIMPipeline,
@@ -44,8 +41,7 @@ from diffusers import (
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
-from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
-from packaging import version
+from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -487,71 +483,3 @@ class PipelineSlowTests(unittest.TestCase):
 
         # the values aren't exactly equal, but the images look the same visually
         assert np.abs(ddpm_images - ddim_images).max() < 1e-1
-
-    @require_torch_gpu
-    def test_stable_diffusion_accelerate_load_works(self):
-        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
-            return
-
-        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
-            return
-
-        model_id = "CompVis/stable-diffusion-v1-4"
-        _ = StableDiffusionPipeline.from_pretrained(
-            model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
-        ).to(torch_device)
-
-    @require_torch_gpu
-    def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
-        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
-            return
-
-        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
-            return
-
-        pipeline_id = "CompVis/stable-diffusion-v1-4"
-
-        torch.cuda.empty_cache()
-        gc.collect()
-
-        tracemalloc.start()
-        pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
-        )
-        pipeline_normal_load.to(torch_device)
-        _, peak_normal = tracemalloc.get_traced_memory()
-        tracemalloc.stop()
-
-        del pipeline_normal_load
-        torch.cuda.empty_cache()
-        gc.collect()
-
-        tracemalloc.start()
-        _ = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
-        )
-        _, peak_accelerate = tracemalloc.get_traced_memory()
-        tracemalloc.stop()
-
-        assert peak_accelerate < peak_normal
-
-    @slow
-    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
-    def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-
-        pipeline_id = "CompVis/stable-diffusion-v1-4"
-        prompt = "Andromeda galaxy in a bottle"
-
-        pipeline = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float32, use_auth_token=True
-        )
-
-        pipeline.cuda_with_minimal_gpu_usage()
-
-        _ = pipeline(prompt)
-
-        mem_bytes = torch.cuda.max_memory_allocated()
-        # make sure that less than 0.8 GB is allocated
-        assert mem_bytes < 0.8 * 10**9
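The deleted `tracemalloc`-based comparison is superseded by the wall-clock test added earlier in this diff. The mechanism behind both is accelerate's meta device; a rough stand-alone sketch of the idea (illustrative only, not diffusers' actual loading code):

```python
from accelerate import init_empty_weights
from torch import nn

# Under init_empty_weights, parameters are created on the "meta" device:
# shapes and dtypes exist but no storage is allocated, so building even a
# very large model skeleton costs almost no time or memory. The real
# checkpoint weights are streamed into place afterwards.
with init_empty_weights():
    model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))

print(model[0].weight.device)  # meta
```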