Commit fba8bde8 authored by bailuo's avatar bailuo
Browse files

update

parents
Pipeline #1808 failed with stages
import os
import torch
import numpy as np
import imgui
import dnnlib
from gui_utils import imgui_utils
#----------------------------------------------------------------------------
class DragWidget:
    """Imgui panel for placing drag handle/target points and painting the
    flexible/fixed region mask used by the drag optimization."""

    def __init__(self, viz):
        self.viz = viz                    # parent visualizer: supplies layout metrics, result, args
        self.point = [-1, -1]             # point currently being placed, stored as [y, x]
        self.points = []                  # committed handle points, each [y, x]
        self.targets = []                 # committed target points, each [y, x]
        self.is_point = True              # True -> next committed point is a handle, else a target
        self.last_click = False           # mouse-down state from the previous frame (edge detection)
        self.is_drag = False              # True while the drag optimization is running
        self.iteration = 0                # optimization step counter shown in the UI
        self.mode = 'point'               # 'point' | 'flexible' | 'fixed'
        self.r_mask = 50                  # mask brush radius in pixels
        self.show_mask = False
        self.mask = torch.ones(256, 256)  # 1 = fixed area, 0 = flexible; resized by init_mask()
        self.lambda_mask = 20             # weight of the fixed-area (mask) loss term
        self.feature_idx = 5              # generator feature map index used by the renderer
        self.r1 = 3                       # motion-supervision radius (passed to renderer)
        self.r2 = 12                      # point-tracking search radius (passed to renderer)
        self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots'))
        self.defer_frames = 0
        self.disabled_time = 0

    def action(self, click, down, x, y):
        """Route a mouse event: place points in 'point' mode, else paint the mask while held."""
        if self.mode == 'point':
            self.add_point(click, x, y)
        elif down:
            self.draw_mask(x, y)

    def add_point(self, click, x, y):
        """Track the cursor while the button is held and commit a point on release.

        Released points are appended alternately to `points` (handles) and
        `targets`, toggling `is_point` each time.
        """
        if click:
            self.point = [y, x]
        elif self.last_click:
            # Button was just released: commit the tracked point.
            if self.is_drag:
                self.stop_drag()
            if self.is_point:
                self.points.append(self.point)
                self.is_point = False
            else:
                self.targets.append(self.point)
                self.is_point = True
        self.last_click = click

    def init_mask(self, w, h):
        """Reset the mask to all-fixed (ones) at the given image resolution."""
        self.width, self.height = w, h
        self.mask = torch.ones(h, w)

    def draw_mask(self, x, y):
        """Paint a filled circle of radius `r_mask` centered at (x, y) into the mask.

        'flexible' mode clears the circle to 0, 'fixed' mode sets it to 1.
        Requires init_mask() to have been called (uses self.width/self.height).
        """
        X = torch.linspace(0, self.width, self.width)
        Y = torch.linspace(0, self.height, self.height)
        # NOTE(review): torch.meshgrid without indexing= emits a warning on
        # newer torch; the default 'ij' ordering is what this code relies on.
        yy, xx = torch.meshgrid(Y, X)
        circle = (xx - x)**2 + (yy - y)**2 < self.r_mask**2
        if self.mode == 'flexible':
            self.mask[circle] = 0
        elif self.mode == 'fixed':
            self.mask[circle] = 1

    def stop_drag(self):
        """Halt the optimization loop and reset the step counter."""
        self.is_drag = False
        self.iteration = 0

    def set_points(self, points):
        self.points = points

    def reset_point(self):
        """Clear all handle/target points; the next committed point is a handle."""
        self.points = []
        self.targets = []
        self.is_point = True

    def load_points(self, suffix):
        """Load whitespace-separated 'y x' pairs from '<path>_<suffix>.txt'.

        Returns:
            List of [y, x] integer pairs; empty (or partial) on failure.
        """
        points = []
        point_path = self.path + f'_{suffix}.txt'
        try:
            with open(point_path, "r") as f:
                for line in f.readlines():
                    y, x = line.split()
                    points.append([int(y), int(x)])
        except (OSError, ValueError):
            # OSError: missing/unreadable file; ValueError: malformed line.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            print(f'Wrong point file path: {point_path}')
        return points

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the widget for one frame and publish its state into viz.args."""
        viz = self.viz
        reset = False
        if show:
            with imgui_utils.grayed_out(self.disabled_time != 0):
                imgui.text('Drag')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Add point', width=viz.button_w, enabled='image' in viz.result):
                    self.mode = 'point'
                imgui.same_line()
                reset = False
                if imgui_utils.button('Reset point', width=viz.button_w, enabled='image' in viz.result):
                    self.reset_point()
                    reset = True
                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Start', width=viz.button_w, enabled='image' in viz.result):
                    self.is_drag = True
                    # Drop unpaired handles so points and targets line up 1:1.
                    if len(self.points) > len(self.targets):
                        self.points = self.points[:len(self.targets)]
                imgui.same_line()
                if imgui_utils.button('Stop', width=viz.button_w, enabled='image' in viz.result):
                    self.stop_drag()
                imgui.text(' ')
                imgui.same_line(viz.label_w)
                imgui.text(f'Steps: {self.iteration}')
                imgui.text('Mask')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Flexible area', width=viz.button_w, enabled='image' in viz.result):
                    self.mode = 'flexible'
                    self.show_mask = True
                imgui.same_line()
                if imgui_utils.button('Fixed area', width=viz.button_w, enabled='image' in viz.result):
                    self.mode = 'fixed'
                    self.show_mask = True
                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Reset mask', width=viz.button_w, enabled='image' in viz.result):
                    # assumes init_mask() already set self.height/self.width — TODO confirm
                    self.mask = torch.ones(self.height, self.width)
                imgui.same_line()
                _clicked, self.show_mask = imgui.checkbox('Show mask', self.show_mask)
                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    changed, self.r_mask = imgui.input_int('Radius', self.r_mask)
                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    changed, self.lambda_mask = imgui.input_int('Lambda', self.lambda_mask)
        self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
        if self.defer_frames > 0:
            self.defer_frames -= 1
        viz.args.is_drag = self.is_drag
        if self.is_drag:
            self.iteration += 1
        viz.args.iteration = self.iteration
        viz.args.points = [point for point in self.points]
        viz.args.targets = [point for point in self.targets]
        viz.args.mask = self.mask
        viz.args.lambda_mask = self.lambda_mask
        viz.args.feature_idx = self.feature_idx
        viz.args.r1 = self.r1
        viz.args.r2 = self.r2
        viz.args.reset = reset
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import os
import numpy as np
import imgui
import dnnlib
import torch
from gui_utils import imgui_utils
#----------------------------------------------------------------------------
class LatentWidget:
    """Imgui panel controlling the latent code: seed, loading a saved latent
    from disk, optimizer step size, and the W / W+ space selection."""

    def __init__(self, viz):
        self.viz = viz
        self.seed = 0
        self.w_plus = True   # True -> optimize in W+ space, False -> W space
        self.reg = 0         # latent regularization weight (UI currently commented out)
        self.lr = 0.001      # optimizer step size
        self.w_path = ''     # user-entered path to a saved latent code
        self.w_load = None   # loaded latent tensor, or None to sample from the seed
        self.defer_frames = 0
        self.disabled_time = 0

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the widget for one frame and publish its state into viz.args."""
        viz = self.viz
        # BUGFIX: `reset_w` and `lr` used to be bound only inside `if show:`,
        # so calling with show=False raised NameError at the viz.args writes
        # below. Define reset_w up front and publish self.lr (kept in sync).
        reset_w = False
        if show:
            with imgui_utils.grayed_out(self.disabled_time != 0):
                imgui.text('Latent')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 8.75):
                    changed, seed = imgui.input_int('Seed', self.seed)
                    if changed:
                        self.seed = seed
                        # reset latent code
                        self.w_load = None
                # load latent code
                imgui.text(' ')
                imgui.same_line(viz.label_w)
                _changed, self.w_path = imgui_utils.input_text('##path', self.w_path, 1024,
                    flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE),
                    width=(-1),
                    help_text='Path to latent code')
                if imgui.is_item_hovered() and not imgui.is_item_active() and self.w_path != '':
                    imgui.set_tooltip(self.w_path)
                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Load latent', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)):
                    assert os.path.isfile(self.w_path), f"{self.w_path} does not exist!"
                    self.w_load = torch.load(self.w_path)
                    # Briefly disable the UI while the new latent takes effect.
                    self.defer_frames = 2
                    self.disabled_time = 0.5
                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.button_w):
                    changed, lr = imgui.input_float('Step Size', self.lr)
                    if changed:
                        self.lr = lr
                # imgui.text(' ')
                # imgui.same_line(viz.label_w)
                # with imgui_utils.item_width(viz.button_w):
                #     changed, reg = imgui.input_float('Regularize', self.reg)
                #     if changed:
                #         self.reg = reg
                imgui.text(' ')
                imgui.same_line(viz.label_w)
                reset_w = imgui_utils.button('Reset', width=viz.button_w, enabled='image' in viz.result)
                imgui.same_line()
                # Two checkboxes emulate a radio toggle between W and W+.
                _clicked, w = imgui.checkbox('w', not self.w_plus)
                if w:
                    self.w_plus = False
                imgui.same_line()
                _clicked, self.w_plus = imgui.checkbox('w+', self.w_plus)
        self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
        if self.defer_frames > 0:
            self.defer_frames -= 1
        viz.args.w0_seed = self.seed
        viz.args.w_load = self.w_load
        viz.args.reg = self.reg
        viz.args.w_plus = self.w_plus
        viz.args.reset_w = reset_w
        viz.args.lr = self.lr
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import glob
import os
import re
import dnnlib
import imgui
import numpy as np
from gui_utils import imgui_utils
from . import renderer
#----------------------------------------------------------------------------
def _locate_results(pattern):
return pattern
#----------------------------------------------------------------------------
class PickleWidget:
    """Imgui panel for selecting and loading a network pickle, with recent
    history, directory browsing, and drag-and-drop support."""

    def __init__(self, viz):
        self.viz = viz
        self.search_dirs = []          # directories offered by the Browse popup
        self.cur_pkl = None            # resolved path of the currently loaded pickle
        self.user_pkl = ''             # raw text in the input field
        self.recent_pkls = []          # most-recently-used list, newest first
        self.browse_cache = dict()     # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...}
        self.browse_refocus = False
        self.load('', ignore_errors=True)

    def add_recent(self, pkl, ignore_errors=False):
        """Add a pickle to the recent list without loading it."""
        try:
            resolved = self.resolve_pkl(pkl)
            if resolved not in self.recent_pkls:
                self.recent_pkls.append(resolved)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt still propagates.
            if not ignore_errors:
                raise

    def load(self, pkl, ignore_errors=False):
        """Resolve `pkl` and schedule it for loading; update the recent list.

        On failure the result is replaced with an error message (or a
        'no pickle' placeholder when `pkl` is empty).
        """
        viz = self.viz
        viz.clear_result()
        viz.skip_frame() # The input field will change on next frame.
        try:
            resolved = self.resolve_pkl(pkl)
            name = resolved.replace('\\', '/').split('/')[-1]
            self.cur_pkl = resolved
            self.user_pkl = resolved
            viz.result.message = f'Loading {name}...'
            viz.defer_rendering()
            # Move to the front of the MRU list.
            if resolved in self.recent_pkls:
                self.recent_pkls.remove(resolved)
            self.recent_pkls.insert(0, resolved)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt still propagates.
            self.cur_pkl = None
            self.user_pkl = pkl
            if pkl == '':
                viz.result = dnnlib.EasyDict(message='No network pickle loaded')
            else:
                viz.result = dnnlib.EasyDict(error=renderer.CapturedException())
            if not ignore_errors:
                raise

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the widget for one frame and publish the pickle into viz.args."""
        viz = self.viz
        recent_pkls = [pkl for pkl in self.recent_pkls if pkl != self.user_pkl]
        if show:
            imgui.text('Pickle')
            imgui.same_line(viz.label_w)
            idx = self.user_pkl.rfind('/')
            changed, self.user_pkl = imgui_utils.input_text('##pkl', self.user_pkl[idx+1:], 1024,
                flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE),
                width=(-1),
                help_text='<PATH> | <URL> | <RUN_DIR> | <RUN_ID> | <RUN_ID>/<KIMG>.pkl')
            if changed:
                self.load(self.user_pkl, ignore_errors=True)
            if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pkl != '':
                imgui.set_tooltip(self.user_pkl)
            # imgui.same_line()
            imgui.text(' ')
            imgui.same_line(viz.label_w)
            if imgui_utils.button('Recent...', width=viz.button_w, enabled=(len(recent_pkls) != 0)):
                imgui.open_popup('recent_pkls_popup')
            imgui.same_line()
            if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=viz.button_w):
                imgui.open_popup('browse_pkls_popup')
                self.browse_cache.clear()
                self.browse_refocus = True
            if imgui.begin_popup('recent_pkls_popup'):
                for pkl in recent_pkls:
                    clicked, _state = imgui.menu_item(pkl)
                    if clicked:
                        self.load(pkl, ignore_errors=True)
                imgui.end_popup()
            if imgui.begin_popup('browse_pkls_popup'):
                def recurse(parents):
                    # Recursively build nested menus of run dirs and pickles.
                    key = tuple(parents)
                    items = self.browse_cache.get(key, None)
                    if items is None:
                        items = self.list_runs_and_pkls(parents)
                        self.browse_cache[key] = items
                    for item in items:
                        if item.type == 'run' and imgui.begin_menu(item.name):
                            recurse([item.path])
                            imgui.end_menu()
                        if item.type == 'pkl':
                            clicked, _state = imgui.menu_item(item.name)
                            if clicked:
                                self.load(item.path, ignore_errors=True)
                    if len(items) == 0:
                        with imgui_utils.grayed_out():
                            imgui.menu_item('No results found')
                recurse(self.search_dirs)
                if self.browse_refocus:
                    imgui.set_scroll_here()
                    viz.skip_frame() # Focus will change on next frame.
                    self.browse_refocus = False
                imgui.end_popup()
            paths = viz.pop_drag_and_drop_paths()
            if paths is not None and len(paths) >= 1:
                self.load(paths[0], ignore_errors=True)
        viz.args.pkl = self.cur_pkl

    def list_runs_and_pkls(self, parents):
        """List run directories ('<digits>-*') and snapshot pickles under `parents`.

        Returns:
            Sorted list of EasyDicts with fields type ('run'|'pkl'), name, path.
        """
        items = []
        run_regex = re.compile(r'\d+-.*')
        pkl_regex = re.compile(r'network-snapshot-\d+\.pkl')
        for parent in set(parents):
            if os.path.isdir(parent):
                for entry in os.scandir(parent):
                    if entry.is_dir() and run_regex.fullmatch(entry.name):
                        items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name)))
                    if entry.is_file() and pkl_regex.fullmatch(entry.name):
                        items.append(dnnlib.EasyDict(type='pkl', name=entry.name, path=os.path.join(parent, entry.name)))
        items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path))
        return items

    def resolve_pkl(self, pattern):
        """Resolve a user-supplied pattern to a loadable pickle path or URL.

        Raises:
            IOError: if a run directory contains no snapshot pickles.
        """
        assert isinstance(pattern, str)
        assert pattern != ''
        # URL => return as is.
        if dnnlib.util.is_url(pattern):
            return pattern
        # Short-hand pattern => locate.
        path = _locate_results(pattern)
        # Run dir => pick the last saved snapshot.
        if os.path.isdir(path):
            pkl_files = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl')))
            if len(pkl_files) == 0:
                raise IOError(f'No network pickle found in "{path}"')
            path = pkl_files[-1]
        # Normalize.
        path = os.path.abspath(path)
        return path
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from socket import has_dualstack_ipv6
import sys
import copy
import traceback
import math
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.cm
import dnnlib
from torch_utils.ops import upfirdn2d
import legacy # pylint: disable=import-error
#----------------------------------------------------------------------------
class CapturedException(Exception):
    """Exception whose message defaults to a snapshot of the exception
    currently being handled (full traceback text, or the message of a
    nested CapturedException)."""

    def __init__(self, msg=None):
        if msg is None:
            # No explicit message: capture the in-flight exception's text.
            current = sys.exc_info()[1]
            assert current is not None
            msg = str(current) if isinstance(current, CapturedException) else traceback.format_exc()
        assert isinstance(msg, str)
        super().__init__(msg)
#----------------------------------------------------------------------------
class CaptureSuccess(Exception):
    """Control-flow exception used to unwind with a successfully captured
    output attached as `out`."""

    def __init__(self, out):
        super().__init__()
        # Stash the captured value for the handler that catches this.
        self.out = out
#----------------------------------------------------------------------------
def add_watermark_np(input_image_array, watermark_text="AI Generated"):
    """Overlay a semi-transparent text watermark in the bottom-right corner.

    Args:
        input_image_array: HxWxC numpy image array (uint8-compatible values).
        watermark_text: Text to draw onto the image.

    Returns:
        RGBA numpy array of the watermarked image.
    """
    image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")
    # Transparent canvas for the text, alpha-composited over the image below.
    txt = Image.new('RGBA', image.size, (255, 255, 255, 0))
    font_size = round(25 / 512 * image.size[0])  # scale text with image width
    try:
        font = ImageFont.truetype('arial.ttf', font_size)
    except OSError:
        # 'arial.ttf' is unavailable on most non-Windows systems; fall back
        # instead of crashing the render path.
        font = ImageFont.load_default()
    d = ImageDraw.Draw(txt)
    # font.getsize() was removed in Pillow 10; textbbox() is the supported API.
    left, top, right, bottom = d.textbbox((0, 0), watermark_text, font=font)
    text_width, text_height = right - left, bottom - top
    text_position = (image.size[0] - text_width - 10, image.size[1] - text_height - 10)
    text_color = (255, 255, 255, 128) # white color with the alpha channel set to semi-transparent
    # Draw the text onto the text canvas
    d.text(text_position, watermark_text, font=font, fill=text_color)
    # Combine the image with the watermark
    watermarked = Image.alpha_composite(image, txt)
    watermarked_array = np.array(watermarked)
    return watermarked_array
#----------------------------------------------------------------------------
class Renderer:
    """Loads generator pickles, renders frames, and runs the point-drag
    optimization step (`_render_drag_impl`) that moves handle points toward
    their targets by optimizing the latent code."""

    def __init__(self, disable_timing=False):
        # Pick the best available device: CUDA > Apple MPS > CPU.
        self._device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
        # MPS has no float64 support, so fall back to float32 there.
        self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
        self._pkl_data = dict() # {pkl: dict | CapturedException, ...}
        self._networks = dict() # {cache_key: torch.nn.Module, ...}
        self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...}
        self._cmaps = dict() # {name: torch.Tensor, ...}
        self._is_timing = False
        if not disable_timing:
            # CUDA events used to time each render() call.
            self._start_event = torch.cuda.Event(enable_timing=True)
            self._end_event = torch.cuda.Event(enable_timing=True)
        self._disable_timing = disable_timing
        self._net_layers = dict() # {cache_key: [dnnlib.EasyDict, ...], ...}

    def render(self, **args):
        """Render one frame and return an EasyDict with image/points/stats/
        error/render_time fields. Any exception is captured into res.error."""
        if self._disable_timing:
            self._is_timing = False
        else:
            self._start_event.record(torch.cuda.current_stream(self._device))
            self._is_timing = True
        res = dnnlib.EasyDict()
        try:
            # Re-initialize the network whenever the pickle, loaded latent,
            # seed, latent space, or an explicit reset changed since last frame.
            init_net = False
            if not hasattr(self, 'G'):
                init_net = True
            if hasattr(self, 'pkl'):
                if self.pkl != args['pkl']:
                    init_net = True
            if hasattr(self, 'w_load'):
                # Identity comparison: a new tensor object means a new load.
                if self.w_load is not args['w_load']:
                    init_net = True
            if hasattr(self, 'w0_seed'):
                if self.w0_seed != args['w0_seed']:
                    init_net = True
            if hasattr(self, 'w_plus'):
                if self.w_plus != args['w_plus']:
                    init_net = True
            if args['reset_w']:
                init_net = True
            res.init_net = init_net
            if init_net:
                self.init_network(res, **args)
            self._render_drag_impl(res, **args)
        except:
            # Deliberate catch-all: render failures are surfaced via res.error.
            res.error = CapturedException()
        if not self._disable_timing:
            self._end_event.record(torch.cuda.current_stream(self._device))
        if 'image' in res:
            res.image = self.to_cpu(res.image).detach().numpy()
            res.image = add_watermark_np(res.image, 'AI Generated')
        if 'stats' in res:
            res.stats = self.to_cpu(res.stats).detach().numpy()
        if 'error' in res:
            res.error = str(res.error)
        # if 'stop' in res and res.stop:
        if self._is_timing and not self._disable_timing:
            self._end_event.synchronize()
            # elapsed_time() is in milliseconds; convert to seconds.
            res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3
            self._is_timing = False
        return res

    def get_network(self, pkl, key, **tweak_kwargs):
        """Load (and cache) the network stored under `key` in pickle `pkl`.

        The pickle contents and the rebuilt module are both memoized; cached
        CapturedExceptions are re-raised on every subsequent request.
        """
        data = self._pkl_data.get(pkl, None)
        if data is None:
            print(f'Loading "{pkl}"... ', end='', flush=True)
            try:
                with dnnlib.util.open_url(pkl, verbose=False) as f:
                    data = legacy.load_network_pkl(f)
                print('Done.')
            except:
                # Cache the failure so the load is not retried every frame.
                data = CapturedException()
                print('Failed!')
            self._pkl_data[pkl] = data
            self._ignore_timing()
        if isinstance(data, CapturedException):
            raise data
        orig_net = data[key]
        cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
        net = self._networks.get(cache_key, None)
        if net is None:
            try:
                # Infer the generator class from the pickle's filename.
                if 'stylegan2' in pkl:
                    from training.networks_stylegan2 import Generator
                elif 'stylegan3' in pkl:
                    from training.networks_stylegan3 import Generator
                elif 'stylegan_human' in pkl:
                    from stylegan_human.training_scripts.sg2.training.networks import Generator
                else:
                    raise NameError('Cannot infer model type from pkl name!')
                print(data[key].init_args)
                print(data[key].init_kwargs)
                if 'stylegan_human' in pkl:
                    net = Generator(*data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
                else:
                    net = Generator(*data[key].init_args, **data[key].init_kwargs)
                # Rebuild from source and copy weights from the pickled module.
                net.load_state_dict(data[key].state_dict())
                net.to(self._device)
            except:
                net = CapturedException()
            self._networks[cache_key] = net
            self._ignore_timing()
        if isinstance(net, CapturedException):
            raise net
        return net

    def _get_pinned_buf(self, ref):
        """Return a cached page-locked staging buffer matching `ref`'s shape/dtype."""
        key = (tuple(ref.shape), ref.dtype)
        buf = self._pinned_bufs.get(key, None)
        if buf is None:
            buf = torch.empty(ref.shape, dtype=ref.dtype).pin_memory()
            self._pinned_bufs[key] = buf
        return buf

    def to_device(self, buf):
        """Copy a CPU tensor to the device via a pinned staging buffer."""
        return self._get_pinned_buf(buf).copy_(buf).to(self._device)

    def to_cpu(self, buf):
        """Copy a device tensor to the CPU via a pinned staging buffer."""
        return self._get_pinned_buf(buf).copy_(buf).clone()

    def _ignore_timing(self):
        # Skip the timing measurement for frames that did one-off slow work.
        self._is_timing = False

    def _apply_cmap(self, x, name='viridis'):
        """Map values in [0, 1] through a matplotlib colormap to uint8 RGB."""
        cmap = self._cmaps.get(name, None)
        if cmap is None:
            # NOTE(review): matplotlib.cm.get_cmap is deprecated in newer
            # matplotlib (use matplotlib.colormaps[name]) — confirm version.
            cmap = matplotlib.cm.get_cmap(name)
            cmap = cmap(np.linspace(0, 1, num=1024), bytes=True)[:, :3]
            cmap = self.to_device(torch.from_numpy(cmap))
            self._cmaps[name] = cmap
        hi = cmap.shape[0] - 1
        x = (x * hi + 0.5).clamp(0, hi).to(torch.int64)
        x = torch.nn.functional.embedding(x, cmap)
        return x

    def init_network(self, res,
        pkl = None,
        w0_seed = 0,
        w_load = None,
        w_plus = True,
        noise_mode = 'const',
        trunc_psi = 0.7,
        trunc_cutoff = None,
        input_transform = None,
        lr = 0.001,
        **kwargs
    ):
        """Load the generator, prepare the starting latent `w`, and build the
        Adam optimizer over it. Populates `res` with network metadata."""
        # Dig up network details.
        self.pkl = pkl
        G = self.get_network(pkl, 'G_ema')
        self.G = G
        res.img_resolution = G.img_resolution
        res.num_ws = G.num_ws
        res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
        res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
        # Set input transform.
        if res.has_input_transform:
            m = np.eye(3)
            try:
                if input_transform is not None:
                    m = np.linalg.inv(np.asarray(input_transform))
            except np.linalg.LinAlgError:
                res.error = CapturedException()
            G.synthesis.input.transform.copy_(torch.from_numpy(m))
        # Generate random latents.
        self.w0_seed = w0_seed
        self.w_load = w_load
        if self.w_load is None:
            # Generate random latents.
            # NOTE(review): z dimension is hard-coded to 512 — presumably
            # matches G.z_dim for all supported models; verify.
            z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device, dtype=self._dtype)
            # Run mapping network.
            label = torch.zeros([1, G.c_dim], device=self._device)
            w = G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff)
        else:
            w = self.w_load.clone().to(self._device)
        self.w0 = w.detach().clone()  # frozen reference latent for regularization
        self.w_plus = w_plus
        if w_plus:
            self.w = w.detach()
        else:
            # W space: optimize a single style vector shared by all layers.
            self.w = w[:, 0, :].detach()
        self.w.requires_grad = True
        self.w_optim = torch.optim.Adam([self.w], lr=lr)
        # Caches for the drag step; reset whenever the points change.
        self.feat_refs = None
        self.points0_pt = None

    def update_lr(self, lr):
        """Rebuild the optimizer with a new learning rate (keeps feature caches)."""
        del self.w_optim
        self.w_optim = torch.optim.Adam([self.w], lr=lr)
        print(f'Rebuild optimizer with lr: {lr}')
        print(' Remain feat_refs and points0_pt')

    def _render_drag_impl(self, res,
        points = [],
        targets = [],
        mask = None,
        lambda_mask = 10,
        reg = 0,
        feature_idx = 5,
        r1 = 3,
        r2 = 12,
        random_seed = 0,
        noise_mode = 'const',
        trunc_psi = 0.7,
        force_fp32 = False,
        layer_name = None,
        sel_channels = 3,
        base_channel = 0,
        img_scale_db = 0,
        img_normalize = False,
        untransform = False,
        is_drag = False,
        reset = False,
        to_pil = False,
        **kwargs
    ):
        """Synthesize an image and, if `is_drag`, perform one drag-optimization
        step: track each handle point by feature matching (radius r2), build a
        motion-supervision loss (radius r1) plus optional mask and latent
        regularization terms, and step the latent optimizer.

        Note: `points` is a list of [y, x] pairs and is updated IN PLACE by
        the tracking step; the new positions are also reported in res.points.
        """
        G = self.G
        ws = self.w
        if ws.dim() == 2:
            # W space: broadcast the single style vector to the first 6 layers
            # and keep the original w0 styles for the remaining layers.
            ws = ws.unsqueeze(1).repeat(1,6,1)
            ws = torch.cat([ws[:,:6,:], self.w0[:,6:,:]], dim=1)
        if hasattr(self, 'points'):
            # Changing the number of points invalidates the feature caches.
            if len(points) != len(self.points):
                reset = True
        if reset:
            self.feat_refs = None
            self.points0_pt = None
        self.points = points
        # Run synthesis network.
        label = torch.zeros([1, G.c_dim], device=self._device)
        img, feat = G(ws, label, truncation_psi=trunc_psi, noise_mode=noise_mode, input_is_w=True, return_feature=True)
        h, w = G.img_resolution, G.img_resolution
        if is_drag:
            X = torch.linspace(0, h, h)
            Y = torch.linspace(0, w, w)
            xx, yy = torch.meshgrid(X, Y)
            # Upsample the chosen feature map to image resolution.
            feat_resize = F.interpolate(feat[feature_idx], [h, w], mode='bilinear')
            if self.feat_refs is None:
                # Cache the initial feature map and per-point reference features.
                self.feat0_resize = F.interpolate(feat[feature_idx].detach(), [h, w], mode='bilinear')
                self.feat_refs = []
                for point in points:
                    py, px = round(point[0]), round(point[1])
                    self.feat_refs.append(self.feat0_resize[:,:,py,px])
                self.points0_pt = torch.Tensor(points).unsqueeze(0).to(self._device) # 1, N, 2
            # Point tracking with feature matching
            with torch.no_grad():
                for j, point in enumerate(points):
                    # Search radius scales with image resolution (r2 at 512px).
                    r = round(r2 / 512 * h)
                    up = max(point[0] - r, 0)
                    down = min(point[0] + r + 1, h)
                    left = max(point[1] - r, 0)
                    right = min(point[1] + r + 1, w)
                    feat_patch = feat_resize[:,:,up:down,left:right]
                    # Nearest feature (L2) to the reference becomes the new point.
                    L2 = torch.linalg.norm(feat_patch - self.feat_refs[j].reshape(1,-1,1,1), dim=1)
                    _, idx = torch.min(L2.view(1,-1), -1)
                    width = right - left
                    point = [idx.item() // width + up, idx.item() % width + left]
                    points[j] = point
            res.points = [[point[0], point[1]] for point in points]
            # Motion supervision
            loss_motion = 0
            res.stop = True
            for j, point in enumerate(points):
                # Direction from handle to target, in (x, y) order.
                direction = torch.Tensor([targets[j][1] - point[1], targets[j][0] - point[0]])
                if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
                    # At least one point is still far from its target.
                    res.stop = False
                if torch.linalg.norm(direction) > 1:
                    distance = ((xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
                    relis, reljs = torch.where(distance < round(r1 / 512 * h))
                    direction = direction / (torch.linalg.norm(direction) + 1e-7)
                    # Sample features one unit step toward the target and pull
                    # the current (detached) features toward them.
                    gridh = (relis+direction[1]) / (h-1) * 2 - 1
                    gridw = (reljs+direction[0]) / (w-1) * 2 - 1
                    grid = torch.stack([gridw,gridh], dim=-1).unsqueeze(0).unsqueeze(0)
                    target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2)
                    loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs].detach(), target)
            loss = loss_motion
            if mask is not None:
                if mask.min() == 0 and mask.max() == 1:
                    # Penalize feature drift inside the fixed (mask==1) region.
                    mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
                    loss_fix = F.l1_loss(feat_resize * mask_usq, self.feat0_resize * mask_usq)
                    loss += lambda_mask * loss_fix
            loss += reg * F.l1_loss(ws, self.w0) # latent code regularization
            if not res.stop:
                self.w_optim.zero_grad()
                loss.backward()
                self.w_optim.step()
        # Scale and convert to uint8.
        img = img[0]
        if img_normalize:
            img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8)
        img = img * (10 ** (img_scale_db / 20))
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
        if to_pil:
            from PIL import Image
            img = img.cpu().numpy()
            img = Image.fromarray(img)
        res.image = img
        res.w = ws.detach().cpu().numpy()
#----------------------------------------------------------------------------
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment