Unverified Commit 3bef6f22 authored by SMG's avatar SMG Committed by GitHub
Browse files

fix : fbcache with controlNet (#360)

parent da7059b5
...@@ -390,19 +390,21 @@ class FluxCachedTransformerBlocks(nn.Module): ...@@ -390,19 +390,21 @@ class FluxCachedTransformerBlocks(nn.Module):
original_dtype = hidden_states.dtype original_dtype = hidden_states.dtype
original_device = hidden_states.device original_device = hidden_states.device
hidden_states = hidden_states.to(self.dtype).to(self.device) hidden_states = hidden_states.to(self.dtype).to(original_device)
encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device) encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(original_device)
temb = temb.to(self.dtype).to(self.device) temb = temb.to(self.dtype).to(original_device)
image_rotary_emb = image_rotary_emb.to(self.device) image_rotary_emb = image_rotary_emb.to(original_device)
if controlnet_block_samples is not None: if controlnet_block_samples is not None:
controlnet_block_samples = ( controlnet_block_samples = (
torch.stack(controlnet_block_samples).to(self.device) if len(controlnet_block_samples) > 0 else None torch.stack(controlnet_block_samples).to(original_device) if len(controlnet_block_samples) > 0 else None
)
if controlnet_single_block_samples is not None and len(controlnet_single_block_samples) > 0:
controlnet_single_block_samples = (
torch.stack(controlnet_single_block_samples).to(original_device)
if len(controlnet_single_block_samples) > 0
else None
) )
if controlnet_single_block_samples:
controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
else:
controlnet_single_block_samples = None
assert image_rotary_emb.ndim == 6 assert image_rotary_emb.ndim == 6
assert image_rotary_emb.shape[0] == 1 assert image_rotary_emb.shape[0] == 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment