Unverified Commit 9bef9f4b authored by Jianbing Wu's avatar Jianbing Wu Committed by GitHub
Browse files

Fix SVD bug (shape of `time_context`) (#7268)



* Fix SVD bug (shape of `time_context`)

* Formatting code

* Formatting src/diffusers/models/transformers/transformer_temporal.py by `make style && make quality`

---------
Co-authored-by: default avatarkevinkhwu <kevinkhwu@tencent.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 7aa45142
...@@ -311,10 +311,10 @@ class TransformerSpatioTemporalModel(nn.Module): ...@@ -311,10 +311,10 @@ class TransformerSpatioTemporalModel(nn.Module):
time_context_first_timestep = time_context[None, :].reshape( time_context_first_timestep = time_context[None, :].reshape(
batch_size, num_frames, -1, time_context.shape[-1] batch_size, num_frames, -1, time_context.shape[-1]
)[:, 0] )[:, 0]
time_context = time_context_first_timestep[None, :].broadcast_to( time_context = time_context_first_timestep[:, None].broadcast_to(
height * width, batch_size, 1, time_context.shape[-1] batch_size, height * width, time_context.shape[-2], time_context.shape[-1]
) )
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1]) time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1])
residual = hidden_states residual = hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment