Unverified Commit fffd964a authored by DefTruth's avatar DefTruth Committed by GitHub
Browse files

fix FLUX.2 context parallel (#12737)

parent 859b8090
...@@ -676,8 +676,8 @@ class Flux2Transformer2DModel( ...@@ -676,8 +676,8 @@ class Flux2Transformer2DModel(
"": { "": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), "img_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), "txt_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
}, },
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment