Unverified Commit 493228a7 authored by Edward Li, committed by GitHub

Fix `AutoencoderTiny` with `use_slicing` (#6850)

* Fix `AutoencoderTiny` with `use_slicing`

When slicing is enabled on `AutoencoderTiny`, the list comprehension in `encode` mistakenly passed the full batch `x` to the encoder for every slice, so the whole batch was encoded once per image instead of each image being encoded on its own.

* Fixed formatting issue
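
For context, a minimal usage sketch of the code path this fix touches; the checkpoint name and tensor sizes below are illustrative assumptions, not taken from this commit. `enable_slicing()` is the public toggle that sets `use_slicing`, after which `encode()` splits the batch into single-image slices:

```python
import torch
from diffusers import AutoencoderTiny

# Illustrative setup: the "madebyollin/taesd" checkpoint and 512x512 inputs
# are assumptions for this sketch, not part of the commit.
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
vae.enable_slicing()  # sets use_slicing, so encode() works slice by slice

# A batch of two images; with slicing enabled, each 1-image slice is
# encoded separately and the latents are concatenated along the batch dim.
images = torch.rand(2, 3, 512, 512)
with torch.no_grad():
    latents = vae.encode(images).latents

print(latents.shape)  # expected (2, 4, 64, 64) for 512x512 inputs
```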
parent 8bf046b7
```diff
@@ -292,7 +292,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         self, x: torch.FloatTensor, return_dict: bool = True
     ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
         if self.use_slicing and x.shape[0] > 1:
-            output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x) for x_slice in x.split(1)]
+            output = [
+                self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
+            ]
             output = torch.cat(output)
         else:
             output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
```
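
To see why the old comprehension was wrong, here is a standalone sketch (a toy stand-in, not the library encoder): calling the encoder on the full batch `x` inside the loop produces one full-batch result per slice, so `torch.cat` yields N*N outputs instead of N.

```python
import torch

# Toy stand-in for self.encoder: identity, keeps the batch dimension as-is.
def encoder(t):
    return t

x = torch.rand(3, 3, 64, 64)  # batch of 3 images

# Old (buggy) pattern: the whole batch is "encoded" once per slice.
buggy = torch.cat([encoder(x) for x_slice in x.split(1)])
# Fixed pattern: each single-image slice is encoded on its own.
fixed = torch.cat([encoder(x_slice) for x_slice in x.split(1)])

print(buggy.shape)  # torch.Size([9, 3, 64, 64]) -- 3x too many rows
print(fixed.shape)  # torch.Size([3, 3, 64, 64])
```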