Unverified Commit 4fc70848 authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

Fix a dimension bug in Transform2d (#2144)

The dimension does not match when `inner_dim` is not equal to `in_channels`.
parent 9213d81b
......@@ -200,7 +200,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
# TODO: should use out_channels for continous projections
if use_linear_projection:
self.proj_out = nn.Linear(in_channels, inner_dim)
self.proj_out = nn.Linear(inner_dim, in_channels)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment