Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
c72e3430
Unverified
Commit
c72e3430
authored
Aug 12, 2022
by
Suraj Patil
Committed by
GitHub
Aug 12, 2022
Browse files
[PNDM in LDM pipeline] use inspect in pipeline instead of unused kwargs (#167)
use inspect instead of unused kwargs
parent
3228eb16
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
3 deletions
+17
-3
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
...s/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+8
-1
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
...tent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+9
-1
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+0
-1
No files found.
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
View file @
c72e3430
import
inspect
from
typing
import
Optional
,
Tuple
,
Union
import
torch
...
...
@@ -59,6 +60,12 @@ class LDMTextToImagePipeline(DiffusionPipeline):
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
accepts_eta
=
"eta"
in
set
(
inspect
.
signature
(
self
.
scheduler
.
step
).
parameters
.
keys
())
extra_kwrags
=
{}
if
accepts_eta
:
extra_kwrags
[
"eta"
]
=
eta
for
t
in
tqdm
(
self
.
scheduler
.
timesteps
):
if
guidance_scale
==
1.0
:
# guidance_scale of 1 means no guidance
...
...
@@ -79,7 +86,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
(
noise_prediction_text
-
noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
eta
=
eta
)[
"prev_sample"
]
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_kwrags
)[
"prev_sample"
]
# scale and decode the image latents with vae
latents
=
1
/
0.18215
*
latents
...
...
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
View file @
c72e3430
import
inspect
import
torch
from
tqdm.auto
import
tqdm
...
...
@@ -31,11 +33,17 @@ class LDMPipeline(DiffusionPipeline):
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
accepts_eta
=
"eta"
in
set
(
inspect
.
signature
(
self
.
scheduler
.
step
).
parameters
.
keys
())
extra_kwrags
=
{}
if
accepts_eta
:
extra_kwrags
[
"eta"
]
=
eta
for
t
in
tqdm
(
self
.
scheduler
.
timesteps
):
# predict the noise residual
noise_prediction
=
self
.
unet
(
latents
,
t
)[
"sample"
]
# compute the previous noisy sample x_t -> x_t-1
latents
=
self
.
scheduler
.
step
(
noise_prediction
,
t
,
latents
,
eta
)[
"prev_sample"
]
latents
=
self
.
scheduler
.
step
(
noise_prediction
,
t
,
latents
,
**
extra_kwrags
)[
"prev_sample"
]
# decode the image latents with the VAE
image
=
self
.
vqvae
.
decode
(
latents
)
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
c72e3430
...
...
@@ -116,7 +116,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
**
kwargs
,
):
if
self
.
counter
<
len
(
self
.
prk_timesteps
):
return
self
.
step_prk
(
model_output
=
model_output
,
timestep
=
timestep
,
sample
=
sample
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment