Unverified Commit 577bd8ea authored by SMG, committed by GitHub

refactor: remove redundant tensor concatenation and slicing in forward pass

parent 8d968511
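
For context, a minimal sketch (with assumed, illustrative tensor shapes) of why the removed concatenate-then-slice pattern is a value-level no-op on `hidden_states`, costing only an extra allocation and copy:

```python
import torch

# Illustrative shapes only (assumed for this sketch); in the real forward
# pass these are the text-token and image-token streams of the transformer.
encoder_hidden_states = torch.randn(1, 512, 3072)
hidden_states = torch.randn(1, 4096, 3072)

# The removed pattern: concatenate along the sequence dimension, then
# immediately slice the prepended text tokens back off.
roundtrip = torch.cat([encoder_hidden_states, hidden_states], dim=1)
roundtrip = roundtrip[:, encoder_hidden_states.shape[1]:, ...]

# The result equals the original hidden_states, so the cat + slice only
# adds an extra allocation and copy before norm_out / proj_out.
assert torch.equal(roundtrip, hidden_states)
```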
@@ -136,8 +136,6 @@ def pulid_forward(
         controlnet_block_samples=controlnet_block_samples,
         controlnet_single_block_samples=controlnet_single_block_samples,
     )
-    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
     hidden_states = self.norm_out(hidden_states, temb)
     output = self.proj_out(hidden_states)
@@ -982,8 +982,6 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
             controlnet_block_samples=controlnet_block_samples,
             controlnet_single_block_samples=controlnet_single_block_samples,
         )
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)