Commit 66d35c07 authored by comfyanonymous's avatar comfyanonymous
Browse files

Improve artifacts on hydit, auraflow and SD3 on specific resolutions.

This breaks seeds for resolutions that are not a multiple of 16 in pixel
resolution by using circular padding instead of reflection padding but
should lower the amount of artifacts when doing img2img at those
resolutions.
parent c75b5060
...@@ -409,7 +409,7 @@ class MMDiT(nn.Module): ...@@ -409,7 +409,7 @@ class MMDiT(nn.Module):
pad_h = (self.patch_size - H % self.patch_size) % self.patch_size pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
pad_w = (self.patch_size - W % self.patch_size) % self.patch_size pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect') x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
x = x.view( x = x.view(
B, B,
C, C,
......
...@@ -69,12 +69,14 @@ class PatchEmbed(nn.Module): ...@@ -69,12 +69,14 @@ class PatchEmbed(nn.Module):
bias: bool = True, bias: bool = True,
strict_img_size: bool = True, strict_img_size: bool = True,
dynamic_img_pad: bool = True, dynamic_img_pad: bool = True,
padding_mode='circular',
dtype=None, dtype=None,
device=None, device=None,
operations=None, operations=None,
): ):
super().__init__() super().__init__()
self.patch_size = (patch_size, patch_size) self.patch_size = (patch_size, patch_size)
self.padding_mode = padding_mode
if img_size is not None: if img_size is not None:
self.img_size = (img_size, img_size) self.img_size = (img_size, img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
...@@ -110,7 +112,7 @@ class PatchEmbed(nn.Module): ...@@ -110,7 +112,7 @@ class PatchEmbed(nn.Module):
if self.dynamic_img_pad: if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect') x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=self.padding_mode)
x = self.proj(x) x = self.proj(x)
if self.flatten: if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment