"python/dgl/transform/functional.py" did not exist on "6294677f8acc6bc040baf922910473e1c82995ba"
Unverified Commit 18c8f10f authored by Aryan, committed by GitHub

[refactor] Flux/Chroma single file implementation + Attention Dispatcher (#11916)



* update

* update

* add coauthor
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* improve test

* handle ip adapter params correctly

* fix chroma qkv fusion test

* fix fastercache implementation

* fix more tests

* fight more tests

* add back set_attention_backend

* update

* update

* make style

* make fix-copies

* make ip adapter processor compatible with attention dispatcher

* refactor chroma as well

* remove rmsnorm assert

* minify and deprecate npu/xla processors

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 7298bdd8
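
Context for the diff below: the "Attention Dispatcher" in the title and the "add back set_attention_backend" entry in the changelog refer to a single entry point for switching a model's attention backend, rather than swapping per-module attention processors. A minimal usage sketch; the checkpoint id and the backend name "native" are illustrative assumptions, not taken from this diff:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
# Route every attention module in the transformer through one dispatcher
# backend instead of assigning attention processors module by module.
pipe.transformer.set_attention_backend("native")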
@@ -28,8 +28,7 @@ from ..test_pipelines_common import (
     FluxIPAdapterTesterMixin,
     PipelineTesterMixin,
     PyramidAttentionBroadcastTesterMixin,
-    check_qkv_fusion_matches_attn_procs_length,
-    check_qkv_fusion_processors_exist,
+    check_qkv_fused_layers_exist,
 )
@@ -171,12 +170,10 @@ class FluxPipelineFastTests(
         # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
         # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
-        )
-        assert check_qkv_fusion_matches_attn_procs_length(
-            pipe.transformer, pipe.transformer.original_attn_processors
-        ), "Something wrong with the attention processors concerning the fused QKV projections."
+        self.assertTrue(
+            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+        )
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
...
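
For context on the assertion above: `fuse_qkv_projections()` concatenates each attention module's separate query/key/value projections into a single `to_qkv` linear layer, so the three projections run as one matmul. A minimal sketch of the idea, assuming equal-width projections; this is illustrative, not the library's implementation:

import torch
import torch.nn as nn

def fuse_qkv(to_q: nn.Linear, to_k: nn.Linear, to_v: nn.Linear) -> nn.Linear:
    # Stack the three projection matrices so a single matmul yields q, k and v.
    in_features = to_q.in_features
    out_features = to_q.out_features + to_k.out_features + to_v.out_features
    to_qkv = nn.Linear(in_features, out_features, bias=to_q.bias is not None)
    with torch.no_grad():
        to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
        if to_q.bias is not None:
            to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))
    return to_qkv

# q, k, v are then recovered by splitting the fused output:
# q, k, v = to_qkv(x).chunk(3, dim=-1)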
@@ -8,11 +8,7 @@ from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPToken
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
 from diffusers.utils.testing_utils import torch_device
 
-from ..test_pipelines_common import (
-    PipelineTesterMixin,
-    check_qkv_fusion_matches_attn_procs_length,
-    check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
 
 
 class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -140,12 +136,10 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
         # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
         # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
-        )
-        assert check_qkv_fusion_matches_attn_procs_length(
-            pipe.transformer, pipe.transformer.original_attn_processors
-        ), "Something wrong with the attention processors concerning the fused QKV projections."
+        self.assertTrue(
+            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+        )
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
...
@@ -15,11 +15,7 @@ from diffusers.utils.testing_utils import (
     torch_device,
 )
 
-from ..test_pipelines_common import (
-    PipelineTesterMixin,
-    check_qkv_fusion_matches_attn_procs_length,
-    check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
 
 
 class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -134,12 +130,10 @@ class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin
         # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
         # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
-        )
-        assert check_qkv_fusion_matches_attn_procs_length(
-            pipe.transformer, pipe.transformer.original_attn_processors
-        ), "Something wrong with the attention processors concerning the fused QKV projections."
+        self.assertTrue(
+            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+        )
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
...
@@ -37,6 +37,7 @@ from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
 from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
+from diffusers.models.attention import AttentionModuleMixin
 from diffusers.models.attention_processor import AttnProcessor
 from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
 from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
@@ -98,6 +99,20 @@ def check_qkv_fusion_processors_exist(model):
     return all(p.startswith("Fused") for p in proc_names)
 
 
+def check_qkv_fused_layers_exist(model, layer_names):
+    is_fused_submodules = []
+    for submodule in model.modules():
+        if not isinstance(submodule, AttentionModuleMixin):
+            continue
+        is_fused_attribute_set = submodule.fused_projections
+        is_fused_layer = True
+        for layer in layer_names:
+            is_fused_layer = is_fused_layer and getattr(submodule, layer, None) is not None
+        is_fused = is_fused_attribute_set and is_fused_layer
+        is_fused_submodules.append(is_fused)
+    return all(is_fused_submodules)
+
+
 class SDFunctionTesterMixin:
     """
     This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
...
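
The new `check_qkv_fused_layers_exist` helper above replaces both processor-based checks: it walks every `AttentionModuleMixin` submodule and requires that the `fused_projections` flag is set and that each named fused layer actually exists on the module. Typical call, matching the updated tests earlier in this diff:

pipe.transformer.fuse_qkv_projections()
# Passes only if every attention module has fused_projections set
# and exposes the combined "to_qkv" projection layer.
assert check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"])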