Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b24f7834
Unverified
Commit
b24f7834
authored
Mar 07, 2024
by
pravdomil
Committed by
GitHub
Mar 07, 2024
Browse files
use self.device (#6595)
* use self.device * use device * fix * fix
parent
3ce905c9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
5 deletions
+7
-5
examples/community/rerender_a_video.py
examples/community/rerender_a_video.py
+7
-5
No files found.
examples/community/rerender_a_video.py
View file @
b24f7834
...
...
@@ -119,11 +119,11 @@ def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5)
@
torch
.
no_grad
()
def
get_warped_and_mask
(
flow_model
,
image1
,
image2
,
image3
=
None
,
pixel_consistency
=
False
):
def
get_warped_and_mask
(
flow_model
,
image1
,
image2
,
image3
=
None
,
pixel_consistency
=
False
,
device
=
None
):
if
image3
is
None
:
image3
=
image1
padder
=
InputPadder
(
image1
.
shape
,
padding_factor
=
8
)
image1
,
image2
=
padder
.
pad
(
image1
[
None
].
cuda
(
),
image2
[
None
].
cuda
(
))
image1
,
image2
=
padder
.
pad
(
image1
[
None
].
to
(
device
),
image2
[
None
].
to
(
device
))
results_dict
=
flow_model
(
image1
,
image2
,
attn_splits_list
=
[
2
],
corr_radius_list
=
[
-
1
],
prop_radius_list
=
[
-
1
],
pred_bidir_flow
=
True
)
...
...
@@ -307,6 +307,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
feature_extractor
:
CLIPImageProcessor
,
image_encoder
=
None
,
requires_safety_checker
:
bool
=
True
,
device
=
None
,
):
super
().
__init__
(
vae
,
...
...
@@ -320,6 +321,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
image_encoder
,
requires_safety_checker
,
)
self
.
to
(
device
)
if
safety_checker
is
None
and
requires_safety_checker
:
logger
.
warning
(
...
...
@@ -374,7 +376,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
attention_type
=
"swin"
,
ffn_dim_expansion
=
4
,
num_transformer_layers
=
6
,
).
to
(
"cuda"
)
).
to
(
self
.
device
)
checkpoint
=
torch
.
utils
.
model_zoo
.
load_url
(
"https://huggingface.co/Anonymous-sub/Rerender/resolve/main/models/gmflow_sintel-0c07dcb3.pth"
,
...
...
@@ -928,13 +930,13 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
prev_image
=
self
.
image_processor
.
preprocess
(
prev_image
).
to
(
dtype
=
torch
.
float32
)
warped_0
,
bwd_occ_0
,
bwd_flow_0
=
get_warped_and_mask
(
self
.
flow_model
,
first_image
,
image
[
0
],
first_result
,
False
self
.
flow_model
,
first_image
,
image
[
0
],
first_result
,
False
,
self
.
device
)
blend_mask_0
=
blur
(
F
.
max_pool2d
(
bwd_occ_0
,
kernel_size
=
9
,
stride
=
1
,
padding
=
4
))
blend_mask_0
=
torch
.
clamp
(
blend_mask_0
+
bwd_occ_0
,
0
,
1
)
warped_pre
,
bwd_occ_pre
,
bwd_flow_pre
=
get_warped_and_mask
(
self
.
flow_model
,
prev_image
[
0
],
image
[
0
],
prev_result
,
False
self
.
flow_model
,
prev_image
[
0
],
image
[
0
],
prev_result
,
False
,
self
.
device
)
blend_mask_pre
=
blur
(
F
.
max_pool2d
(
bwd_occ_pre
,
kernel_size
=
9
,
stride
=
1
,
padding
=
4
))
blend_mask_pre
=
torch
.
clamp
(
blend_mask_pre
+
bwd_occ_pre
,
0
,
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment