Unverified Commit 354d35ad authored by DefTruth, committed by GitHub

bugfix: fix chrono-edit context parallel (#12660)



* bugfix: fix chrono-edit context parallel

* bugfix: fix chrono-edit context parallel

* Update src/diffusers/models/transformers/transformer_chronoedit.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/models/transformers/transformer_chronoedit.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Clean up comments in transformer_chronoedit.py

Removed unnecessary comments regarding parallelization in cross-attention.

* fix style

* fix qc

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 544ba677
src/diffusers/models/transformers/transformer_chronoedit.py

@@ -67,7 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
     return key_img, value_img


-# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
+# modified from diffusers.models.transformers.transformer_wan.WanAttnProcessor
 class WanAttnProcessor:
     _attention_backend = None
     _parallel_config = None
@@ -137,7 +137,8 @@ class WanAttnProcessor:
                 dropout_p=0.0,
                 is_causal=False,
                 backend=self._attention_backend,
-                parallel_config=self._parallel_config,
+                # Reference: https://github.com/huggingface/diffusers/pull/12660
+                parallel_config=None,
             )
             hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
@@ -150,7 +151,8 @@ class WanAttnProcessor:
             dropout_p=0.0,
             is_causal=False,
             backend=self._attention_backend,
-            parallel_config=self._parallel_config,
+            # Reference: https://github.com/huggingface/diffusers/pull/12660
+            parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
         )
         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
@@ -568,9 +570,11 @@ class ChronoEditTransformer3DModel(
         "blocks.0": {
             "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
         },
-        "blocks.*": {
-            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
-        },
+        # Reference: https://github.com/huggingface/diffusers/pull/12660
+        # We need to disable the splitting of encoder_hidden_states because the
+        # image_encoder consistently generates 257 tokens for image_embed. This causes
+        # the shape of encoder_hidden_states, whose token count is always 769 (512 + 257)
+        # after concatenation, to be indivisible by the number of devices in the CP group.
         "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
     }
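
Why the split fails, as a minimal standalone sketch (hypothetical helper and shapes, not diffusers code): context parallelism shards the sequence dimension evenly across ranks, so it needs seq_len % world_size == 0, and the concatenated text-plus-image context is always 769 tokens (512 + 257), which no CP degree greater than one divides.

import torch

# Hypothetical sketch: even sequence sharding requires divisibility.
def split_for_context_parallel(x: torch.Tensor, world_size: int, dim: int = 1):
    if x.shape[dim] % world_size != 0:
        raise ValueError(f"seq_len {x.shape[dim]} is not divisible by world_size {world_size}")
    return torch.chunk(x, world_size, dim=dim)

# 512 text tokens + 257 image tokens from the image_encoder = 769.
encoder_hidden_states = torch.randn(1, 512 + 257, 64)
# The latent self-attention sequence (length here is illustrative) splits fine.
hidden_states = torch.randn(1, 1024, 64)

split_for_context_parallel(hidden_states, world_size=2)  # OK: 512 tokens per rank
try:
    split_for_context_parallel(encoder_hidden_states, world_size=2)
except ValueError as err:
    print(err)  # seq_len 769 is not divisible by world_size 2

This is the reasoning behind the gated argument in the patch: self-attention (encoder_hidden_states is None) keeps self._parallel_config, while cross-attention passes parallel_config=None and attends over the full 769-token context on every rank.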