"git@developer.sourcefind.cn:change/sglang.git" did not exist on "d297cda2c513e747c804cc4d2bd526ecb66349df"
Unverified commit 2777264e authored by Pedro Cuenca, committed by GitHub

`enable_model_cpu_offload` (#2285)

* enable_model_offload PoC

It's surprisingly more involved than expected, see comments in the PR.

* Rename final_offload_hook

* Invoke the vae forward hook manually.

* Completely remove decoder.

* Style

* apply_forward_hook decorator

* Rename method.

* Style

* Copy enable_model_cpu_offload

* Fix copies.

* Remove comment.

* Fix copies

* Missing import

* Fix doc-builder style.

* Merge main and fix again.

* Add docs

* Fix docs.

* Add a couple of tests.

* style
parent 6eaebe82
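For context, the new API added by this PR is used roughly as follows. This is a minimal sketch mirroring the slow tests below; the prompt string is illustrative:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
# Do NOT call pipe.to("cuda"): with model offloading each sub-model is moved
# to the GPU only while it runs and is returned to the CPU afterwards.
pipe.enable_model_cpu_offload()  # requires accelerate >= 0.17.0

image = pipe("a photo of an astronaut riding a horse").images[0]
```

Unlike `enable_sequential_cpu_offload`, the idea is to offload whole sub-models (text encoder, UNet, VAE) one at a time rather than individual weights, trading a little memory headroom for speed.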
@@ -192,7 +192,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         hooks.
         """
-        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+        if not hasattr(self.unet, "_hf_hook"):
             return self.device
         for module in self.unet.modules():
             if (
...
@@ -170,7 +170,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         hooks.
         """
-        if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
+        if not hasattr(self.image_unet, "_hf_hook"):
             return self.device
         for module in self.image_unet.modules():
             if (
...
@@ -97,7 +97,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         hooks.
         """
-        if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
+        if not hasattr(self.image_unet, "_hf_hook"):
             return self.device
         for module in self.image_unet.modules():
             if (
...
@@ -121,7 +121,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         hooks.
         """
-        if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
+        if not hasattr(self.image_unet, "_hf_hook"):
             return self.device
         for module in self.image_unet.modules():
             if (
...
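All four hunks drop the `meta`-device requirement: with model offloading the pipeline is not on the meta device, yet the execution device still has to be read from the accelerate hooks. For orientation, here is a sketch of the whole `_execution_device` property these hunks touch (the hunks are truncated after `if (`; `self.unet` stands in for `self.image_unet` in the Versatile Diffusion pipelines):

```python
# Illustrative sketch of the full property; the diff above only shows its first lines.
@property
def _execution_device(self):
    # No offload hooks registered: fall back to the pipeline's own device.
    if not hasattr(self.unet, "_hf_hook"):
        return self.device
    # Otherwise, find any submodule whose accelerate hook records an execution device.
    for module in self.unet.modules():
        if (
            hasattr(module, "_hf_hook")
            and hasattr(module._hf_hook, "execution_device")
            and module._hf_hook.execution_device is not None
        ):
            return torch.device(module._hf_hook.execution_device)
    return self.device
```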
@@ -18,6 +18,7 @@ import os
 from packaging import version

 from .. import __version__
+from .accelerate_utils import apply_forward_hook
 from .constants import (
     CONFIG_NAME,
     DEPRECATED_REVISION_ARGS,
@@ -44,6 +45,7 @@ from .import_utils import (
     DummyObject,
     OptionalDependencyNotAvailable,
     is_accelerate_available,
+    is_accelerate_version,
     is_flax_available,
     is_inflect_available,
     is_k_diffusion_available,
...
new file: src/diffusers/utils/accelerate_utils.py

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Accelerate utilities: Utilities related to accelerate
"""
from packaging import version

from .import_utils import is_accelerate_available


if is_accelerate_available():
    import accelerate


def apply_forward_hook(method):
    """
    Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful
    for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the
    appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`].

    This decorator looks inside the internal `_hf_hook` property to find a registered offload hook.

    :param method: The method to decorate. This method should be a method of a PyTorch module.
    """
    accelerate_version = version.parse(accelerate.__version__).base_version
    if version.parse(accelerate_version) < version.parse("0.17.0"):
        return method

    def wrapper(self, *args, **kwargs):
        if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):
            self._hf_hook.pre_forward(self)
        return method(self, *args, **kwargs)

    return wrapper
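A hedged sketch of how the decorator is meant to be applied, based on its docstring (the `AutoencoderKL` `encode`/`decode` methods are the intended targets; the toy module below is purely illustrative):

```python
import torch
from diffusers.utils import apply_forward_hook

class TinyAutoencoder(torch.nn.Module):
    # Illustrative module: `encode` is not `forward`, so accelerate's CpuOffload
    # hook would never fire for it without the decorator.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    @apply_forward_hook
    def encode(self, x):
        # If a `_hf_hook` is registered, the decorator calls
        # `self._hf_hook.pre_forward(self)` first, moving the module
        # to its execution device before the computation runs.
        return self.proj(x)
```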
@@ -476,6 +476,20 @@ def is_transformers_version(operation: str, version: str):
     return compare_versions(parse(_transformers_version), operation, version)


+def is_accelerate_version(operation: str, version: str):
+    """
+    Args:
+    Compares the current Accelerate version to a given reference with an operation.
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _accelerate_available:
+        return False
+    return compare_versions(parse(_accelerate_version), operation, version)
+
+
 def is_k_diffusion_version(operation: str, version: str):
     """
     Args:
...
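A quick illustration of how this helper is typically used for feature gating. The exact version string and error message below are assumptions, chosen to match the 0.17.0 cutoff used in accelerate_utils.py above:

```python
from diffusers.utils import is_accelerate_version

# Model-level offloading relies on hook utilities that only exist in
# sufficiently recent accelerate releases; gate on the version before using them.
if not is_accelerate_version(">=", "0.17.0.dev0"):
    raise ImportError("`enable_model_cpu_offload` requires `accelerate` v0.17.0 or higher.")
```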
@@ -789,6 +789,59 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
         # make sure that less than 2.8 GB is allocated
         assert mem_bytes < 2.8 * 10**9

+    def test_stable_diffusion_pipeline_with_model_offloading(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        inputs = self.get_inputs(torch_device, dtype=torch.float16)
+
+        # Normal inference
+        pipe = StableDiffusionPipeline.from_pretrained(
+            "CompVis/stable-diffusion-v1-4",
+            torch_dtype=torch.float16,
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        outputs = pipe(**inputs)
+        mem_bytes = torch.cuda.max_memory_allocated()
+
+        # With model offloading
+        # Reload but don't move to cuda
+        pipe = StableDiffusionPipeline.from_pretrained(
+            "CompVis/stable-diffusion-v1-4",
+            torch_dtype=torch.float16,
+        )
+
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        pipe.enable_model_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+        outputs_offloaded = pipe(**inputs)
+        mem_bytes_offloaded = torch.cuda.max_memory_allocated()
+
+        assert np.abs(outputs.images - outputs_offloaded.images).max() < 1e-3
+        assert mem_bytes_offloaded < mem_bytes
+        assert mem_bytes_offloaded < 3.5 * 10**9
+        for module in pipe.text_encoder, pipe.unet, pipe.vae, pipe.safety_checker:
+            assert module.device == torch.device("cpu")
+
+        # With attention slicing
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        pipe.enable_attention_slicing()
+        _ = pipe(**inputs)
+        mem_bytes_slicing = torch.cuda.max_memory_allocated()
+        assert mem_bytes_slicing < mem_bytes_offloaded
+        assert mem_bytes_slicing < 3 * 10**9
+

 @nightly
 @require_torch_gpu
...
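The test above exercises `enable_model_cpu_offload`; the pipeline-side implementation is not part of this excerpt. As a rough illustration only, model-level offloading can be wired with accelerate's `cpu_offload_with_hook`, chaining each sub-model's hook to the previous one so that a model is pushed back to the CPU when the next one starts. The function name and the `final_offload_hook` attribute below mirror the commit messages but are not taken from the diff shown here:

```python
import torch
from accelerate import cpu_offload_with_hook

def enable_model_cpu_offload_sketch(pipe, gpu_id=0):
    # Illustrative sketch, not the code added by this PR: each sub-model gets a
    # hook that moves it to the GPU right before it runs; the previous model's
    # hook offloads it back to the CPU, so only one sub-model is resident at a time.
    device = torch.device(f"cuda:{gpu_id}")
    hook = None
    for model in [pipe.text_encoder, pipe.unet, pipe.vae]:
        _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
    # Keep the last hook around so the VAE can be flushed off the GPU after decoding.
    pipe.final_offload_hook = hook
```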
@@ -342,6 +342,47 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
         # make sure that less than 2.2 GB is allocated
         assert mem_bytes < 2.2 * 10**9

+    def test_stable_diffusion_pipeline_with_model_offloading(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        inputs = self.get_inputs(torch_device, dtype=torch.float16)
+
+        # Normal inference
+        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+            "CompVis/stable-diffusion-v1-4",
+            safety_checker=None,
+            torch_dtype=torch.float16,
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe(**inputs)
+        mem_bytes = torch.cuda.max_memory_allocated()
+
+        # With model offloading
+        # Reload but don't move to cuda
+        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+            "CompVis/stable-diffusion-v1-4",
+            safety_checker=None,
+            torch_dtype=torch.float16,
+        )
+
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        pipe.enable_model_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+        _ = pipe(**inputs)
+        mem_bytes_offloaded = torch.cuda.max_memory_allocated()
+
+        assert mem_bytes_offloaded < mem_bytes
+        for module in pipe.text_encoder, pipe.unet, pipe.vae:
+            assert module.device == torch.device("cpu")
+
     def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
         init_image = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
...
@@ -393,6 +393,57 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
         # make sure that less than 2.8 GB is allocated
         assert mem_bytes < 2.8 * 10**9

+    def test_stable_diffusion_pipeline_with_model_offloading(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        inputs = self.get_inputs(torch_device, dtype=torch.float16)
+
+        # Normal inference
+        pipe = StableDiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-2-base",
+            torch_dtype=torch.float16,
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        outputs = pipe(**inputs)
+        mem_bytes = torch.cuda.max_memory_allocated()
+
+        # With model offloading
+        # Reload but don't move to cuda
+        pipe = StableDiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-2-base",
+            torch_dtype=torch.float16,
+        )
+
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        pipe.enable_model_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+        outputs_offloaded = pipe(**inputs)
+        mem_bytes_offloaded = torch.cuda.max_memory_allocated()
+
+        assert np.abs(outputs.images - outputs_offloaded.images).max() < 1e-3
+        assert mem_bytes_offloaded < mem_bytes
+        assert mem_bytes_offloaded < 3 * 10**9
+        for module in pipe.text_encoder, pipe.unet, pipe.vae:
+            assert module.device == torch.device("cpu")
+
+        # With attention slicing
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        pipe.enable_attention_slicing()
+        _ = pipe(**inputs)
+        mem_bytes_slicing = torch.cuda.max_memory_allocated()
+        assert mem_bytes_slicing < mem_bytes_offloaded
+

 @nightly
 @require_torch_gpu
...