Unverified commit 15241225, authored by Suraj Patil, committed by GitHub

[Transformer2DModel] don't norm twice (#1381)

don't norm twice
parent f07a16e0
...
@@ -201,13 +201,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         if not self.use_linear_projection:
             hidden_states = self.proj_in(hidden_states)
             inner_dim = hidden_states.shape[1]
             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
         else:
-            hidden_states = self.norm(hidden_states)
             inner_dim = hidden_states.shape[1]
             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
             hidden_states = self.proj_in(hidden_states)
...
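Why the removed line matters: `hidden_states` has already been passed through `self.norm` before the branch, so the `else` branch was normalizing a second time. With an affine normalization layer (learned scale and shift, as in the GroupNorm used here), applying the layer twice does not give the same result as applying it once. A minimal NumPy sketch (a stand-in single-group normalization with per-channel affine parameters, not the actual diffusers layer) illustrates this:

```python
import numpy as np

def affine_norm(x, gamma, beta, eps=1e-5):
    # Standardize each sample over its feature axis (one group),
    # then apply learned per-feature scale (gamma) and shift (beta).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))      # batch of 2 samples, 8 features
gamma = rng.standard_normal(8)       # hypothetical learned scale
beta = rng.standard_normal(8)        # hypothetical learned shift

once = affine_norm(x, gamma, beta)
twice = affine_norm(once, gamma, beta)

# The affine parameters make the layer non-idempotent:
print(np.allclose(once, twice))
```

Because the second standardization is computed over already-scaled-and-shifted features, `twice` differs from `once`, which is why dropping the redundant call is a functional fix and not just a cleanup.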