"vscode:/vscode.git/clone" did not exist on "ffc3bb5806c8b06ff299c85063f7728f6ec3c733"
Unverified Commit 9a147b82 authored by Aryan, committed by GitHub

Module Group Offloading (#10503)



* update

* fix

* non_blocking; handle parameters and buffers

* update

* Group offloading with cuda stream prefetching (#10516)

* cuda stream prefetch

* remove breakpoints

* update

* copy model hook implementation from pab

* update; ~very workaround based implementation but it seems to work as expected; needs cleanup and rewrite

* more workarounds to make it actually work

* cleanup

* rewrite

* update

* make sure to sync current stream before overwriting with pinned params

not doing so will lead to erroneous computations on the GPU and cause bad results

* better check

* update

* remove hook implementation to not deal with merge conflict

* re-add hook changes

* why use more memory when less memory do trick

* why still use slightly more memory when less memory do trick

* optimise

* add model tests

* add pipeline tests

* update docs

* add layernorm and groupnorm

* address review comments

* improve tests; add docs

* improve docs

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* apply suggestions from code review

* update tests

* apply suggestions from review

* enable_group_offloading -> enable_group_offload for naming consistency

* raise errors if multiple offloading strategies used; add relevant tests

* handle .to() when group offload applied

* refactor some repeated code

* remove unintentional change from merge conflict

* handle .cuda()

---------
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent ab428207
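For context on the API the commit message above refers to: group offloading is enabled either through `enable_group_offload` on diffusers models or through the `apply_group_offloading` helper in `diffusers.hooks` for other `torch.nn.Module`s. Below is a minimal usage sketch based on the calls exercised by the test added in this commit; the pipeline checkpoint, dtype, and device are placeholders rather than anything prescribed by the change.

```python
# Minimal sketch of the group offloading API exercised by the tests in this commit.
# The checkpoint, dtype, and devices below are placeholders.
import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.hooks import apply_group_offloading

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.bfloat16
)
onload_device = torch.device("cuda")

# Diffusers models expose enable_group_offload directly; the onload device is the
# first positional argument, mirroring the test below.
pipe.transformer.enable_group_offload(
    onload_device, offload_type="block_level", num_blocks_per_group=1
)

# Modules that are not diffusers ModelMixin subclasses (e.g. the text encoders)
# go through the hook helper instead.
for text_encoder in (pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3):
    apply_group_offloading(text_encoder, onload_device=onload_device, offload_type="leaf_level")

# The VAE is simply kept on the accelerator here, as in the test below.
pipe.vae.to(onload_device)

image = pipe("a photo of an astronaut riding a horse").images[0]
```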
@@ -76,6 +76,7 @@ class StableDiffusion2PipelineFastTests(
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
     test_layerwise_casting = True
+    test_group_offloading = True
 
     def get_dummy_components(self):
         torch.manual_seed(0)
...
@@ -36,6 +36,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
     )
     batch_params = frozenset(["prompt", "negative_prompt"])
     test_layerwise_casting = True
+    test_group_offloading = True
 
     def get_dummy_components(self):
         torch.manual_seed(0)
...
@@ -76,6 +76,7 @@ class StableDiffusionXLPipelineFastTests(
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
     test_layerwise_casting = True
+    test_group_offloading = True
 
     def get_dummy_components(self, time_cond_proj_dim=None):
         torch.manual_seed(0)
...
@@ -29,6 +29,7 @@ from diffusers import (
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
+from diffusers.hooks import apply_group_offloading
 from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
@@ -47,6 +48,7 @@ from diffusers.utils.testing_utils import (
     require_accelerator,
     require_hf_hub_version_greater,
     require_torch,
+    require_torch_gpu,
     require_transformers_version_greater,
     skip_mps,
     torch_device,
@@ -990,6 +992,7 @@ class PipelineTesterMixin:
     test_xformers_attention = True
     test_layerwise_casting = False
+    test_group_offloading = False
     supports_dduf = True
 
     def get_generator(self, seed):
@@ -2044,6 +2047,79 @@ class PipelineTesterMixin:
         inputs = self.get_dummy_inputs(torch_device)
         _ = pipe(**inputs)[0]
 
+    @require_torch_gpu
+    def test_group_offloading_inference(self):
+        if not self.test_group_offloading:
+            return
+
+        def create_pipe():
+            torch.manual_seed(0)
+            components = self.get_dummy_components()
+            pipe = self.pipeline_class(**components)
+            pipe.set_progress_bar_config(disable=None)
+            return pipe
+
+        def enable_group_offload_on_component(pipe, group_offloading_kwargs):
+            # We intentionally don't test VAEs here. This is because some tests enable tiling on the VAE. If
+            # tiling is enabled and a forward pass is run, and CUDA streams are used, the execution order of
+            # the layers is not traced correctly. This causes errors. To apply group offloading to a VAE, a
+            # warmup forward pass (even with dummy small inputs) is recommended.
+            for component_name in [
+                "text_encoder",
+                "text_encoder_2",
+                "text_encoder_3",
+                "transformer",
+                "unet",
+                "controlnet",
+            ]:
+                if not hasattr(pipe, component_name):
+                    continue
+                component = getattr(pipe, component_name)
+                if not getattr(component, "_supports_group_offloading", True):
+                    continue
+                if hasattr(component, "enable_group_offload"):
+                    # For diffusers ModelMixin implementations
+                    component.enable_group_offload(torch.device(torch_device), **group_offloading_kwargs)
+                else:
+                    # For other models not part of diffusers
+                    apply_group_offloading(
+                        component, onload_device=torch.device(torch_device), **group_offloading_kwargs
+                    )
+                self.assertTrue(
+                    all(
+                        module._diffusers_hook.get_hook("group_offloading") is not None
+                        for module in component.modules()
+                        if hasattr(module, "_diffusers_hook")
+                    )
+                )
+            for component_name in ["vae", "vqvae"]:
+                if hasattr(pipe, component_name):
+                    getattr(pipe, component_name).to(torch_device)
+
+        def run_forward(pipe):
+            torch.manual_seed(0)
+            inputs = self.get_dummy_inputs(torch_device)
+            return pipe(**inputs)[0]
+
+        pipe = create_pipe().to(torch_device)
+        output_without_group_offloading = run_forward(pipe)
+
+        pipe = create_pipe()
+        enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1})
+        output_with_group_offloading1 = run_forward(pipe)
+
+        pipe = create_pipe()
+        enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"})
+        output_with_group_offloading2 = run_forward(pipe)
+
+        if torch.is_tensor(output_without_group_offloading):
+            output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
+            output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy()
+            output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy()
+
+        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
+        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))
+
 
 @is_staging_test
 class PipelinePushToHubTester(unittest.TestCase):
...
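The helper in the test above deliberately leaves the VAE on the accelerator instead of group-offloading it, because with tiling enabled and CUDA stream prefetching in use the layer execution order is only traced correctly after a warmup pass. A hypothetical sketch of that recommended warmup follows; the checkpoint, latent shape, and the `use_stream` flag are assumptions on top of what this diff shows.

```python
# Hypothetical sketch of the warmup pass recommended in the test comment above
# before group-offloading a VAE. The checkpoint, latent size, and use_stream
# flag are assumptions, not part of this commit's diff.
import torch
from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading

onload_device = torch.device("cuda")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # placeholder checkpoint
vae.enable_tiling()

apply_group_offloading(vae, onload_device=onload_device, offload_type="leaf_level", use_stream=True)

# One small dummy decode records the layer execution order so that subsequent
# stream-prefetched decodes of real latents onload groups in the right order.
with torch.no_grad():
    dummy_latents = torch.randn(1, vae.config.latent_channels, 8, 8, device=onload_device)
    _ = vae.decode(dummy_latents)
```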