Commit 1900e511 authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix potential issue.

parent 276f8fce
...@@ -625,7 +625,7 @@ class SpatialTransformer(nn.Module): ...@@ -625,7 +625,7 @@ class SpatialTransformer(nn.Module):
x = self.norm(x) x = self.norm(x)
if not self.use_linear: if not self.use_linear:
x = self.proj_in(x) x = self.proj_in(x)
x = x.movedim(1, -1).flatten(1, 2).contiguous() x = x.movedim(1, 3).flatten(1, 2).contiguous()
if self.use_linear: if self.use_linear:
x = self.proj_in(x) x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):
...@@ -633,7 +633,7 @@ class SpatialTransformer(nn.Module): ...@@ -633,7 +633,7 @@ class SpatialTransformer(nn.Module):
x = block(x, context=context[i], transformer_options=transformer_options) x = block(x, context=context[i], transformer_options=transformer_options)
if self.use_linear: if self.use_linear:
x = self.proj_out(x) x = self.proj_out(x)
x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(-1, 1).contiguous() x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
if not self.use_linear: if not self.use_linear:
x = self.proj_out(x) x = self.proj_out(x)
return x + x_in return x + x_in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment