Unverified Commit a9cb08af authored by Wang, Yi's avatar Wang, Yi Committed by GitHub
Browse files

fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled (#12562)



* fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled
Signed-off-by: default avatarWang, Yi <yi.a.wang@intel.com>

* address review comment
Signed-off-by: default avatarWang, Yi A <yi.a.wang@intel.com>

* refine
Signed-off-by: default avatarWang, Yi A <yi.a.wang@intel.com>

---------
Signed-off-by: default avatarWang, Yi <yi.a.wang@intel.com>
Signed-off-by: default avatarWang, Yi A <yi.a.wang@intel.com>
parent 9f669e7b
...@@ -203,10 +203,12 @@ class ContextParallelSplitHook(ModelHook): ...@@ -203,10 +203,12 @@ class ContextParallelSplitHook(ModelHook):
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor: def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims: if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
raise ValueError( logger.warning_once(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions." f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
) )
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) return x
else:
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
class ContextParallelGatherHook(ModelHook): class ContextParallelGatherHook(ModelHook):
......
...@@ -555,6 +555,9 @@ class WanTransformer3DModel( ...@@ -555,6 +555,9 @@ class WanTransformer3DModel(
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
}, },
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
"": {
"timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
},
} }
@register_to_config @register_to_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment