Unverified Commit 9a147b82 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Module Group Offloading (#10503)



* update

* fix

* non_blocking; handle parameters and buffers

* update

* Group offloading with cuda stream prefetching (#10516)

* cuda stream prefetch

* remove breakpoints

* update

* copy model hook implementation from pab

* update; ~very workaround based implementation but it seems to work as expected; needs cleanup and rewrite

* more workarounds to make it actually work

* cleanup

* rewrite

* update

* make sure to sync current stream before overwriting with pinned params

not doing so will lead to erroneous computations on the GPU and cause bad results

* better check

* update

* remove hook implementation to not deal with merge conflict

* re-add hook changes

* why use more memory when less memory do trick

* why still use slightly more memory when less memory do trick

* optimise

* add model tests

* add pipeline tests

* update docs

* add layernorm and groupnorm

* address review comments

* improve tests; add docs

* improve docs

* Apply suggestions from code review
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* apply suggestions from code review

* update tests

* apply suggestions from review

* enable_group_offloading -> enable_group_offload for naming consistency

* raise errors if multiple offloading strategies used; add relevant tests

* handle .to() when group offload applied

* refactor some repeated code

* remove unintentional change from merge conflict

* handle .cuda()

---------
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
parent ab428207
......@@ -59,6 +59,7 @@ class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -127,6 +127,7 @@ class ControlNetPipelineFastTests(
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
......
......@@ -76,6 +76,7 @@ class StableDiffusionXLControlNetPipelineFastTests(
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
......
......@@ -51,6 +51,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -60,6 +60,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
)
batch_params = frozenset(["prompt", "negative_prompt"])
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(
self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm", use_dual_attention=False
......
......@@ -140,6 +140,7 @@ class ControlNetXSPipelineFastTests(
test_attention_slicing = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
......
......@@ -79,6 +79,7 @@ class StableDiffusionXLControlNetXSPipelineFastTests(
test_attention_slicing = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -35,6 +35,7 @@ class FluxPipelineFastTests(
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
......
......@@ -23,6 +23,7 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -24,6 +24,7 @@ class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
batch_params = frozenset(["prompt"])
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -54,6 +54,7 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
......
......@@ -54,6 +54,7 @@ class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTeste
required_optional_params = PipelineTesterMixin.required_optional_params
test_layerwise_casting = True
test_group_offloading = True
pab_config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
......
......@@ -47,6 +47,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -33,6 +33,7 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM
supports_dduf = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -56,6 +56,7 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -56,6 +56,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr
]
)
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
cross_attention_dim = 8
......
......@@ -51,6 +51,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
required_optional_params = PipelineTesterMixin.required_optional_params
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -56,6 +56,7 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
required_optional_params = PipelineTesterMixin.required_optional_params
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -53,6 +53,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -124,6 +124,7 @@ class StableDiffusionPipelineFastTests(
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, time_cond_proj_dim=None):
cross_attention_dim = 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment