Unverified commit 18c8f10f, authored by Aryan, committed by GitHub
Browse files

[refactor] Flux/Chroma single file implementation + Attention Dispatcher (#11916)



* update

* update

* add coauthor
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* improve test

* handle ip adapter params correctly

* fix chroma qkv fusion test

* fix fastercache implementation

* fix more tests

* fight more tests

* add back set_attention_backend

* update

* update

* make style

* make fix-copies

* make ip adapter processor compatible with attention dispatcher

* refactor chroma as well

* remove rmsnorm assert

* minify and deprecate npu/xla processors

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 7298bdd8
......@@ -28,8 +28,7 @@ from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
check_qkv_fused_layers_exist,
)
......@@ -171,12 +170,10 @@ class FluxPipelineFastTests(
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer), (
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
self.assertTrue(
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
......
......@@ -8,11 +8,7 @@ from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPToken
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import torch_device
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
......@@ -140,12 +136,10 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer), (
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
self.assertTrue(
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
......
......@@ -15,11 +15,7 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
......@@ -134,12 +130,10 @@ class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer), (
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
self.assertTrue(
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
......
......@@ -37,6 +37,7 @@ from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
......@@ -98,6 +99,20 @@ def check_qkv_fusion_processors_exist(model):
return all(p.startswith("Fused") for p in proc_names)
def check_qkv_fused_layers_exist(model, layer_names):
    """Report whether every attention submodule of `model` has its projections fused.

    Only submodules that are instances of `AttentionModuleMixin` are inspected; all
    other submodules are ignored. An inspected submodule counts as fused when its
    `fused_projections` attribute is truthy and every attribute named in
    `layer_names` is present on it with a value other than `None`.

    Args:
        model: Object exposing a `modules()` iterator over its submodules.
        layer_names: Attribute names of the fused projection layers to look for
            (for example `["to_qkv"]`).

    Returns:
        `True` if every inspected submodule counts as fused, `False` otherwise.
        When the model contains no `AttentionModuleMixin` submodule the result is
        vacuously `True`.
    """
    fused_flags = []
    attention_modules = (
        candidate for candidate in model.modules() if isinstance(candidate, AttentionModuleMixin)
    )
    for attention in attention_modules:
        # Read the flag first so a missing `fused_projections` attribute still
        # surfaces as an AttributeError before any layer lookup happens.
        flag_is_set = attention.fused_projections
        has_every_layer = all(getattr(attention, name, None) is not None for name in layer_names)
        fused_flags.append(flag_is_set and has_every_layer)
    return all(fused_flags)
class SDFunctionTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment