Unverified commit beacaa55, authored by Aryan, committed by GitHub
Browse files

[core] Layerwise Upcasting (#10347)



* update

* update

* make style

* remove dynamo disable

* add coauthor
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* update

* update

* update

* update mixin

* add some basic tests

* update

* update

* non_blocking

* improvements

* update

* norm.* -> norm

* apply suggestions from review

* add example

* update hook implementation to the latest changes from pyramid attention broadcast

* deinitialize should raise an error

* update doc page

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* update

* refactor

* fix _always_upcast_modules for asym ae and vq_model

* fix lumina embedding forward to not depend on weight dtype

* refactor tests

* add simple lora inference tests

* _always_upcast_modules -> _precision_sensitive_module_patterns

* remove todo comments about review; revert changes to self.dtype in unets because .dtype on ModelMixin should be able to handle fp8 weight case

* check layer dtypes in lora test

* fix UNet1DModelTests::test_layerwise_upcasting_inference

* _precision_sensitive_module_patterns -> _skip_layerwise_casting_patterns based on feedback

* skip test in NCSNppModelTests

* skip tests for AutoencoderTinyTests

* skip tests for AutoencoderOobleckTests

* skip tests for UNet1DModelTests - unsupported pytorch operations

* layerwise_upcasting -> layerwise_casting

* skip tests for UNetRLModelTests; needs next pytorch release for currently unimplemented operation support

* add layerwise fp8 pipeline test

* use xfail

* Apply suggestions from code review
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32-fp32 comparison (required for a few models' test to pass)

* add note about memory consumption on tesla CI runner for failing test

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent a6476822
...@@ -52,6 +52,7 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -52,6 +52,7 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params required_optional_params = PipelineTesterMixin.required_optional_params
test_layerwise_casting = True
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -46,6 +46,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -46,6 +46,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
] ]
) )
test_xformers_attention = False test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -32,6 +32,7 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM ...@@ -32,6 +32,7 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM
batch_params = frozenset(["prompt", "negative_prompt"]) batch_params = frozenset(["prompt", "negative_prompt"])
supports_dduf = False supports_dduf = False
test_layerwise_casting = True
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -55,6 +55,7 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -55,6 +55,7 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
] ]
) )
test_xformers_attention = False test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -55,6 +55,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr ...@@ -55,6 +55,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr
"callback_on_step_end_tensor_inputs", "callback_on_step_end_tensor_inputs",
] ]
) )
test_layerwise_casting = True
def get_dummy_components(self): def get_dummy_components(self):
cross_attention_dim = 8 cross_attention_dim = 8
......
...@@ -50,6 +50,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -50,6 +50,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params required_optional_params = PipelineTesterMixin.required_optional_params
test_layerwise_casting = True
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -55,6 +55,7 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -55,6 +55,7 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params required_optional_params = PipelineTesterMixin.required_optional_params
test_layerwise_casting = True
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -52,6 +52,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -52,6 +52,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
] ]
) )
test_xformers_attention = False test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -123,6 +123,7 @@ class StableDiffusionPipelineFastTests( ...@@ -123,6 +123,7 @@ class StableDiffusionPipelineFastTests(
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
test_layerwise_casting = True
def get_dummy_components(self, time_cond_proj_dim=None): def get_dummy_components(self, time_cond_proj_dim=None):
cross_attention_dim = 8 cross_attention_dim = 8
......
...@@ -75,6 +75,7 @@ class StableDiffusion2PipelineFastTests( ...@@ -75,6 +75,7 @@ class StableDiffusion2PipelineFastTests(
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
test_layerwise_casting = True
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -35,6 +35,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin): ...@@ -35,6 +35,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
] ]
) )
batch_params = frozenset(["prompt", "negative_prompt"]) batch_params = frozenset(["prompt", "negative_prompt"])
test_layerwise_casting = True
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -75,6 +75,7 @@ class StableDiffusionXLPipelineFastTests( ...@@ -75,6 +75,7 @@ class StableDiffusionXLPipelineFastTests(
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
test_layerwise_casting = True
def get_dummy_components(self, time_cond_proj_dim=None): def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -987,7 +987,7 @@ class PipelineTesterMixin: ...@@ -987,7 +987,7 @@ class PipelineTesterMixin:
test_attention_slicing = True test_attention_slicing = True
test_xformers_attention = True test_xformers_attention = True
test_layerwise_casting = False
supports_dduf = True supports_dduf = True
def get_generator(self, seed): def get_generator(self, seed):
...@@ -2027,6 +2027,21 @@ class PipelineTesterMixin: ...@@ -2027,6 +2027,21 @@ class PipelineTesterMixin:
elif isinstance(pipeline_out, torch.Tensor) and isinstance(loaded_pipeline_out, torch.Tensor): elif isinstance(pipeline_out, torch.Tensor) and isinstance(loaded_pipeline_out, torch.Tensor):
assert torch.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol) assert torch.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol)
def test_layerwise_casting_inference(self):
    """Smoke-test a full pipeline forward pass with the denoiser's weights
    stored in ``torch.float8_e4m3fn`` and computed in ``torch.bfloat16``.

    Only checks that inference completes without raising; the output value
    itself is not inspected here.
    """
    # Per-pipeline opt-in flag (defaults to False on PipelineTesterMixin);
    # silently no-op for pipelines that have not enabled the test.
    if not self.test_layerwise_casting:
        return
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    # Move the whole pipeline to the compute dtype before enabling casting.
    pipe.to(torch_device, dtype=torch.bfloat16)
    pipe.set_progress_bar_config(disable=None)
    # Transformer-based pipelines expose `transformer`; UNet-based ones `unet`.
    denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
    # Weights live in fp8 (e4m3) storage and are upcast layer-by-layer to
    # bfloat16 at compute time.
    denoiser.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
    inputs = self.get_dummy_inputs(torch_device)
    _ = pipe(**inputs)[0]
@is_staging_test @is_staging_test
class PipelinePushToHubTester(unittest.TestCase): class PipelinePushToHubTester(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.