Unverified Commit b24f7834 authored by pravdomil's avatar pravdomil Committed by GitHub
Browse files

use self.device (#6595)

* use self.device

* use device

* fix

* fix
parent 3ce905c9
...@@ -119,11 +119,11 @@ def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5) ...@@ -119,11 +119,11 @@ def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5)
@torch.no_grad() @torch.no_grad()
def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False): def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False, device=None):
if image3 is None: if image3 is None:
image3 = image1 image3 = image1
padder = InputPadder(image1.shape, padding_factor=8) padder = InputPadder(image1.shape, padding_factor=8)
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) image1, image2 = padder.pad(image1[None].to(device), image2[None].to(device))
results_dict = flow_model( results_dict = flow_model(
image1, image2, attn_splits_list=[2], corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=True image1, image2, attn_splits_list=[2], corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=True
) )
...@@ -307,6 +307,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline): ...@@ -307,6 +307,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
feature_extractor: CLIPImageProcessor, feature_extractor: CLIPImageProcessor,
image_encoder=None, image_encoder=None,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
device=None,
): ):
super().__init__( super().__init__(
vae, vae,
...@@ -320,6 +321,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline): ...@@ -320,6 +321,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
image_encoder, image_encoder,
requires_safety_checker, requires_safety_checker,
) )
self.to(device)
if safety_checker is None and requires_safety_checker: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
...@@ -374,7 +376,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline): ...@@ -374,7 +376,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
attention_type="swin", attention_type="swin",
ffn_dim_expansion=4, ffn_dim_expansion=4,
num_transformer_layers=6, num_transformer_layers=6,
).to("cuda") ).to(self.device)
checkpoint = torch.utils.model_zoo.load_url( checkpoint = torch.utils.model_zoo.load_url(
"https://huggingface.co/Anonymous-sub/Rerender/resolve/main/models/gmflow_sintel-0c07dcb3.pth", "https://huggingface.co/Anonymous-sub/Rerender/resolve/main/models/gmflow_sintel-0c07dcb3.pth",
...@@ -928,13 +930,13 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline): ...@@ -928,13 +930,13 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32) prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32)
warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask( warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
self.flow_model, first_image, image[0], first_result, False self.flow_model, first_image, image[0], first_result, False, self.device
) )
blend_mask_0 = blur(F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4)) blend_mask_0 = blur(F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1) blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)
warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask( warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
self.flow_model, prev_image[0], image[0], prev_result, False self.flow_model, prev_image[0], image[0], prev_result, False, self.device
) )
blend_mask_pre = blur(F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4)) blend_mask_pre = blur(F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1) blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment