Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
9a92b817
Unverified
Commit
9a92b817
authored
Oct 30, 2024
by
Aryan
Committed by
GitHub
Oct 30, 2024
Browse files
Allegro VAE fix (#9811)
fix
parent
0d1d267b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
8 deletions
+2
-8
src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
+2
-8
No files found.
src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
View file @
9a92b817
...
@@ -1091,8 +1091,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
...
@@ -1091,8 +1091,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
sample_posterior
:
bool
=
False
,
sample_posterior
:
bool
=
False
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
encoder_local_batch_size
:
int
=
2
,
decoder_local_batch_size
:
int
=
2
,
)
->
Union
[
DecoderOutput
,
torch
.
Tensor
]:
)
->
Union
[
DecoderOutput
,
torch
.
Tensor
]:
r
"""
r
"""
Args:
Args:
...
@@ -1103,18 +1101,14 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
...
@@ -1103,18 +1101,14 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
generator (`torch.Generator`, *optional*):
generator (`torch.Generator`, *optional*):
PyTorch random number generator.
PyTorch random number generator.
encoder_local_batch_size (`int`, *optional*, defaults to 2):
Local batch size for the encoder's batch inference.
decoder_local_batch_size (`int`, *optional*, defaults to 2):
Local batch size for the decoder's batch inference.
"""
"""
x
=
sample
x
=
sample
posterior
=
self
.
encode
(
x
,
local_batch_size
=
encoder_local_batch_size
).
latent_dist
posterior
=
self
.
encode
(
x
).
latent_dist
if
sample_posterior
:
if
sample_posterior
:
z
=
posterior
.
sample
(
generator
=
generator
)
z
=
posterior
.
sample
(
generator
=
generator
)
else
:
else
:
z
=
posterior
.
mode
()
z
=
posterior
.
mode
()
dec
=
self
.
decode
(
z
,
local_batch_size
=
decoder_local_batch_size
).
sample
dec
=
self
.
decode
(
z
).
sample
if
not
return_dict
:
if
not
return_dict
:
return
(
dec
,)
return
(
dec
,)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment