Unverified Commit d2df40c6 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Add VAE tiling option for SD3 (#8791)

update
parent 2261510b
...@@ -360,7 +360,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -360,7 +360,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for j in range(0, x.shape[3], overlap_size): for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile) tile = self.encoder(tile)
tile = self.quant_conv(tile) if self.config.use_quant_conv:
tile = self.quant_conv(tile)
row.append(tile) row.append(tile)
rows.append(row) rows.append(row)
result_rows = [] result_rows = []
...@@ -409,7 +410,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -409,7 +410,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
row = [] row = []
for j in range(0, z.shape[3], overlap_size): for j in range(0, z.shape[3], overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
tile = self.post_quant_conv(tile) if self.config.use_post_quant_conv:
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile) decoded = self.decoder(tile)
row.append(decoded) row.append(decoded)
rows.append(row) rows.append(row)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment