Unverified Commit 4fc70848 authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

Fix a dimension bug in Transform2d (#2144)

The dimension does not match when `inner_dim` is not equal to `in_channels`.
parent 9213d81b
...@@ -200,7 +200,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): ...@@ -200,7 +200,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous: if self.is_input_continuous:
# TODO: should use out_channels for continous projections # TODO: should use out_channels for continous projections
if use_linear_projection: if use_linear_projection:
self.proj_out = nn.Linear(in_channels, inner_dim) self.proj_out = nn.Linear(inner_dim, in_channels)
else: else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized: elif self.is_input_vectorized:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment