Unverified Commit beacaa55 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[core] Layerwise Upcasting (#10347)



* update

* update

* make style

* remove dynamo disable

* add coauthor
Co-Authored-By: default avatarDhruv Nair <dhruv.nair@gmail.com>

* update

* update

* update

* update mixin

* add some basic tests

* update

* update

* non_blocking

* improvements

* update

* norm.* -> norm

* apply suggestions from review

* add example

* update hook implementation to the latest changes from pyramid attention broadcast

* deinitialize should raise an error

* update doc page

* Apply suggestions from code review
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* update

* refactor

* fix _always_upcast_modules for asym ae and vq_model

* fix lumina embedding forward to not depend on weight dtype

* refactor tests

* add simple lora inference tests

* _always_upcast_modules -> _precision_sensitive_module_patterns

* remove todo comments about review; revert changes to self.dtype in unets because .dtype on ModelMixin should be able to handle fp8 weight case

* check layer dtypes in lora test

* fix UNet1DModelTests::test_layerwise_upcasting_inference

* _precision_sensitive_module_patterns -> _skip_layerwise_casting_patterns based on feedback

* skip test in NCSNppModelTests

* skip tests for AutoencoderTinyTests

* skip tests for AutoencoderOobleckTests

* skip tests for UNet1DModelTests - unsupported pytorch operations

* layerwise_upcasting -> layerwise_casting

* skip tests for UNetRLModelTests; needs next pytorch release for currently unimplemented operation support

* add layerwise fp8 pipeline test

* use xfail

* Apply suggestions from code review
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>

* add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32-fp32 comparison (required for a few models' test to pass)

* add note about memory consumption on tesla CI runner for failing test

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
parent a6476822
......@@ -60,6 +60,7 @@ class AnimateDiffPipelineFastTests(
"callback_on_step_end_tensor_inputs",
]
)
test_layerwise_casting = True
def get_dummy_components(self):
cross_attention_dim = 8
......
......@@ -30,6 +30,7 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
]
)
batch_params = frozenset(["prompt", "negative_prompt"])
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -58,6 +58,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
)
test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -55,6 +55,7 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas
]
)
test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -56,6 +56,7 @@ class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
)
test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -58,6 +58,7 @@ class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
)
test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -126,6 +126,7 @@ class ControlNetPipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_layerwise_casting = True
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
......
......@@ -75,6 +75,7 @@ class StableDiffusionXLControlNetPipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_layerwise_casting = True
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
......
......@@ -50,6 +50,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -57,6 +57,7 @@ class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMix
]
)
batch_params = frozenset(["prompt", "negative_prompt"])
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -59,6 +59,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
]
)
batch_params = frozenset(["prompt", "negative_prompt"])
test_layerwise_casting = True
def get_dummy_components(
self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm", use_dual_attention=False
......
......@@ -139,6 +139,7 @@ class ControlNetXSPipelineFastTests(
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_attention_slicing = False
test_layerwise_casting = True
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
......
......@@ -78,6 +78,7 @@ class StableDiffusionXLControlNetXSPipelineFastTests(
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_attention_slicing = False
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -31,6 +31,7 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapte
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -22,6 +22,7 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -23,6 +23,7 @@ class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -55,6 +55,7 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -53,6 +53,7 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -61,6 +61,7 @@ class I2VGenXLPipelineFastTests(SDFunctionTesterMixin, PipelineTesterMixin, unit
required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"])
supports_dduf = False
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -48,6 +48,7 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
supports_dduf = False
test_layerwise_casting = True
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment