Unverified commit 98c9aac1, authored by Patrick von Platen, committed by GitHub

[SDXL] Fix all sequential offload (#4010)

* Fix all sequential offload

* make style

* make style
parent e3d71ad8
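
The pipeline change below removes the blanket assumption that `text_encoder` is always present: SDXL checkpoints such as the refiner ship only `text_encoder_2`, so offloading `text_encoder` unconditionally raised an error. A minimal usage sketch of the fixed entry point (the checkpoint name is an assumption; any SDXL variant loaded with `text_encoder=None` triggers the old failure):

```python
import torch

from diffusers import StableDiffusionXLImg2ImgPipeline

# The SDXL refiner loads with text_encoder=None, the case this PR guards.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()  # previously offloaded a None text_encoder
```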
@@ -176,7 +176,6 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin):
         """
         self.vae.disable_tiling()
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.enable_sequential_cpu_offload
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -196,10 +195,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin):
         self.to("cpu", silence_dtype_warnings=True)
         torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
-        for cpu_offloaded_model in [self.unet, self.text_encoder, self.text_encoder_2, self.vae]:
+        for cpu_offloaded_model in [self.unet, self.text_encoder_2, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.enable_model_cpu_offload
+        if self.text_encoder is not None:
+            cpu_offload(self.text_encoder, device)
+
     def enable_model_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
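
Per the docstrings above, sequential offload trades speed for maximum memory savings, while `enable_model_cpu_offload` keeps whole models resident for lower overhead; both paths now need the same `None` guard. A standalone sketch of the guarded pattern, assuming only `accelerate.cpu_offload` and the attribute names visible in the diff (the class itself is illustrative, not the diffusers implementation):

```python
import torch
from accelerate import cpu_offload

class OffloadableSDXLLike:
    """Toy container mirroring the attribute layout in the diff."""

    def __init__(self, unet, text_encoder, text_encoder_2, vae):
        self.unet = unet
        self.text_encoder = text_encoder  # may be None (e.g. the refiner)
        self.text_encoder_2 = text_encoder_2
        self.vae = vae

    def enable_sequential_cpu_offload(self, gpu_id=0):
        device = torch.device(f"cuda:{gpu_id}")
        # Modules every SDXL pipeline carries are offloaded unconditionally...
        for model in [self.unet, self.text_encoder_2, self.vae]:
            cpu_offload(model, device)
        # ...while the optional first text encoder is guarded, which is the fix.
        if self.text_encoder is not None:
            cpu_offload(self.text_encoder, device)
```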
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import unittest
 
 import numpy as np
@@ -22,12 +21,11 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProject
 
 from diffusers import (
     AutoencoderKL,
-    DiffusionPipeline,
     EulerDiscreteScheduler,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
-from diffusers.utils import slow, torch_device
+from diffusers.utils import torch_device
 from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -190,38 +188,31 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
     def test_inference_batch_single_identical(self):
         super().test_inference_batch_single_identical(expected_max_diff=3e-3)
 
+    @require_torch_gpu
+    def test_stable_diffusion_xl_offloads(self):
+        pipes = []
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components).to(torch_device)
+        pipes.append(sd_pipe)
 
-@slow
-@require_torch_gpu
-class StableDiffusionXLPipelineSlowTests(unittest.TestCase):
-    def tearDown(self):
-        super().tearDown()
-        gc.collect()
-        torch.cuda.empty_cache()
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe.enable_model_cpu_offload()
+        pipes.append(sd_pipe)
 
-    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
-        generator = torch.Generator(device=generator_device).manual_seed(seed)
-        latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
-        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
-        inputs = {
-            "prompt": "a photograph of an astronaut riding a horse",
-            "latents": latents,
-            "generator": generator,
-            "num_inference_steps": 3,
-            "guidance_scale": 7.5,
-            "output_type": "numpy",
-        }
-        return inputs
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe.enable_sequential_cpu_offload()
+        pipes.append(sd_pipe)
 
+        image_slices = []
+        for pipe in pipes:
+            pipe.unet.set_default_attn_processor()
 
-    def test_stable_diffusion_default_euler(self):
-        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        inputs = self.get_inputs(torch_device)
-        image = pipe(**inputs).images
+            inputs = self.get_dummy_inputs(torch_device)
+            image = pipe(**inputs).images
 
-        image_slice = image[0, -3:, -3:, -1].flatten()
+            image_slices.append(image[0, -3:, -3:, -1].flatten())
 
-        assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.49493, 0.47896, 0.40798, 0.54214, 0.53212, 0.48202, 0.47656, 0.46329, 0.48506])
-        assert np.abs(image_slice - expected_slice).max() < 7e-3
+        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+        assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
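
The replacement fast test checks cross-mode consistency rather than a hard-coded reference slice: the same dummy pipeline runs plain, with model offload, and with sequential offload, and `set_default_attn_processor()` pins one attention implementation so the three outputs are comparable. The comparison idiom, extracted as a small helper (the function names are mine; the slicing and tolerance follow the diff):

```python
import numpy as np

def corner_slice(image: np.ndarray) -> np.ndarray:
    # Flatten the 3x3 bottom-right corner of the last channel, as the test does.
    return image[0, -3:, -3:, -1].flatten()

def assert_offload_consistent(reference: np.ndarray, candidate: np.ndarray, tol: float = 1e-3) -> None:
    # An offloaded run must reproduce the non-offloaded image within tolerance.
    assert np.abs(corner_slice(reference) - corner_slice(candidate)).max() < tol
```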
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import random
 import unittest
@@ -23,12 +22,11 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProject
 
 from diffusers import (
     AutoencoderKL,
-    DiffusionPipeline,
     EulerDiscreteScheduler,
     StableDiffusionXLImg2ImgPipeline,
     UNet2DConditionModel,
 )
-from diffusers.utils import floats_tensor, slow, torch_device
+from diffusers.utils import floats_tensor, torch_device
 from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
 
 from ..pipeline_params import (
@@ -205,38 +203,31 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
         # make sure that it's equal
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
 
+    @require_torch_gpu
+    def test_stable_diffusion_xl_offloads(self):
+        pipes = []
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
+        pipes.append(sd_pipe)
 
-@slow
-@require_torch_gpu
-class StableDiffusionXLImg2ImgPipelineSlowTests(unittest.TestCase):
-    def tearDown(self):
-        super().tearDown()
-        gc.collect()
-        torch.cuda.empty_cache()
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+        sd_pipe.enable_model_cpu_offload()
+        pipes.append(sd_pipe)
 
-    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
-        generator = torch.Generator(device=generator_device).manual_seed(seed)
-        latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
-        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
-        inputs = {
-            "prompt": "a photograph of an astronaut riding a horse",
-            "latents": latents,
-            "generator": generator,
-            "num_inference_steps": 3,
-            "guidance_scale": 7.5,
-            "output_type": "numpy",
-        }
-        return inputs
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+        sd_pipe.enable_sequential_cpu_offload()
+        pipes.append(sd_pipe)
 
+        image_slices = []
+        for pipe in pipes:
+            pipe.unet.set_default_attn_processor()
 
-    def test_stable_diffusion_default_euler(self):
-        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        inputs = self.get_inputs(torch_device)
-        image = pipe(**inputs).images
+            inputs = self.get_dummy_inputs(torch_device)
+            image = pipe(**inputs).images
 
-        image_slice = image[0, -3:, -3:, -1].flatten()
+            image_slices.append(image[0, -3:, -3:, -1].flatten())
 
-        assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.49493, 0.47896, 0.40798, 0.54214, 0.53212, 0.48202, 0.47656, 0.46329, 0.48506])
-        assert np.abs(image_slice - expected_slice).max() < 7e-3
+        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+        assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
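
The img2img tests mirror the text2img ones; the only extra moving part is the init image in the dummy inputs, which is why the `floats_tensor` import stays. A sketch of how such an input is typically built in these tests (the shape and the `/ 2 + 0.5` shift are assumptions about `get_dummy_inputs`, which is not shown in this diff):

```python
from diffusers.utils import floats_tensor

# Hypothetical dummy init image: random floats shifted toward the [0, 1]
# range image pipelines expect; the shape is illustrative.
image = floats_tensor((1, 3, 64, 64))
image = image / 2 + 0.5
```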