Unverified Commit 8a4c3e50 authored by William Held's avatar William Held Committed by GitHub
Browse files

Width was typod as weight (#1800)

* Width was typod as weight

* Run Black
parent 68e24259
...@@ -204,17 +204,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin): ...@@ -204,17 +204,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
""" """
# 1. Input # 1. Input
if self.is_input_continuous: if self.is_input_continuous:
batch, channel, height, weight = hidden_states.shape batch, channel, height, width = hidden_states.shape
residual = hidden_states residual = hidden_states
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
if not self.use_linear_projection: if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states) hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1] inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else: else:
inner_dim = hidden_states.shape[1] inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states) hidden_states = self.proj_in(hidden_states)
elif self.is_input_vectorized: elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states) hidden_states = self.latent_image_embedding(hidden_states)
...@@ -231,15 +231,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin): ...@@ -231,15 +231,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 3. Output # 3. Output
if self.is_input_continuous: if self.is_input_continuous:
if not self.use_linear_projection: if not self.use_linear_projection:
hidden_states = ( hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)
else: else:
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)
hidden_states = ( hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
output = hidden_states + residual output = hidden_states + residual
elif self.is_input_vectorized: elif self.is_input_vectorized:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment