Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
6281d206
"vscode:/vscode.git/clone" did not exist on "94ecb8eb968f6497bb696d0746455607f18353dd"
Unverified
Commit
6281d206
authored
Sep 25, 2023
by
Carson Katri
Committed by
GitHub
Sep 25, 2023
Browse files
Add callbacks to `WuerstchenDecoderPipeline` and `WuerstchenCombinedPipeline` (#5154)
parent
28254c79
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
2 deletions
+33
-2
src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
+13
-2
src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
...sers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
+20
-0
No files found.
src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
View file @
6281d206
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
List
,
Optional
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -202,6 +202,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
...
@@ -202,6 +202,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
latents
:
Optional
[
torch
.
FloatTensor
]
=
None
,
latents
:
Optional
[
torch
.
FloatTensor
]
=
None
,
output_type
:
Optional
[
str
]
=
"pil"
,
output_type
:
Optional
[
str
]
=
"pil"
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
callback
:
Optional
[
Callable
[[
int
,
int
,
torch
.
FloatTensor
],
None
]]
=
None
,
callback_steps
:
int
=
1
,
):
):
"""
"""
Function invoked when calling the pipeline for generation.
Function invoked when calling the pipeline for generation.
...
@@ -240,6 +242,12 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
...
@@ -240,6 +242,12 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
(`np.array`) or `"pt"` (`torch.Tensor`).
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
Examples:
...
@@ -315,7 +323,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
...
@@ -315,7 +323,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
latents
=
self
.
prepare_latents
(
latent_features_shape
,
dtype
,
device
,
generator
,
latents
,
self
.
scheduler
)
latents
=
self
.
prepare_latents
(
latent_features_shape
,
dtype
,
device
,
generator
,
latents
,
self
.
scheduler
)
# 6. Run denoising loop
# 6. Run denoising loop
for
t
in
self
.
progress_bar
(
timesteps
[:
-
1
]):
for
i
,
t
in
enumerate
(
self
.
progress_bar
(
timesteps
[:
-
1
])
)
:
ratio
=
t
.
expand
(
latents
.
size
(
0
)).
to
(
dtype
)
ratio
=
t
.
expand
(
latents
.
size
(
0
)).
to
(
dtype
)
effnet
=
(
effnet
=
(
torch
.
cat
([
image_embeddings
,
torch
.
zeros_like
(
image_embeddings
)])
torch
.
cat
([
image_embeddings
,
torch
.
zeros_like
(
image_embeddings
)])
...
@@ -343,6 +351,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
...
@@ -343,6 +351,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
generator
=
generator
,
generator
=
generator
,
).
prev_sample
).
prev_sample
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
# 10. Scale and decode the image latents with vq-vae
# 10. Scale and decode the image latents with vq-vae
latents
=
self
.
vqgan
.
config
.
scale_factor
*
latents
latents
=
self
.
vqgan
.
config
.
scale_factor
*
latents
images
=
self
.
vqgan
.
decode
(
latents
).
sample
.
clamp
(
0
,
1
)
images
=
self
.
vqgan
.
decode
(
latents
).
sample
.
clamp
(
0
,
1
)
...
...
src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
View file @
6281d206
...
@@ -161,6 +161,10 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
...
@@ -161,6 +161,10 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
latents
:
Optional
[
torch
.
FloatTensor
]
=
None
,
latents
:
Optional
[
torch
.
FloatTensor
]
=
None
,
output_type
:
Optional
[
str
]
=
"pil"
,
output_type
:
Optional
[
str
]
=
"pil"
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
prior_callback
:
Optional
[
Callable
[[
int
,
int
,
torch
.
FloatTensor
],
None
]]
=
None
,
prior_callback_steps
:
int
=
1
,
callback
:
Optional
[
Callable
[[
int
,
int
,
torch
.
FloatTensor
],
None
]]
=
None
,
callback_steps
:
int
=
1
,
):
):
"""
"""
Function invoked when calling the pipeline for generation.
Function invoked when calling the pipeline for generation.
...
@@ -222,6 +226,18 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
...
@@ -222,6 +226,18 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
(`np.array`) or `"pt"` (`torch.Tensor`).
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
prior_callback (`Callable`, *optional*):
A function that will be called every `prior_callback_steps` steps during inference. The function will be
called with the following arguments: `prior_callback(step: int, timestep: int, latents: torch.FloatTensor)`.
prior_callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
Examples:
...
@@ -244,6 +260,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
...
@@ -244,6 +260,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
latents
=
latents
,
latents
=
latents
,
output_type
=
"pt"
,
output_type
=
"pt"
,
return_dict
=
False
,
return_dict
=
False
,
callback
=
prior_callback
,
callback_steps
=
prior_callback_steps
,
)
)
image_embeddings
=
prior_outputs
[
0
]
image_embeddings
=
prior_outputs
[
0
]
...
@@ -257,6 +275,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
...
@@ -257,6 +275,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
generator
=
generator
,
generator
=
generator
,
output_type
=
output_type
,
output_type
=
output_type
,
return_dict
=
return_dict
,
return_dict
=
return_dict
,
callback
=
callback
,
callback_steps
=
callback_steps
,
)
)
return
outputs
return
outputs
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment