Unverified Commit d886e497 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `SpatialTransformer` (#578)



* Fix SpatialTransformer

* Fix SpatialTransformer
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent ab3fd671
......@@ -144,10 +144,11 @@ class SpatialTransformer(nn.Module):
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=context)
hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
return hidden_states + residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment