Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
DragGAN_pytorch
Commits
fba8bde8
Commit
fba8bde8
authored
Oct 29, 2024
by
bailuo
Browse files
update
parents
Pipeline
#1808
failed with stages
Changes
224
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
825 additions
and
0 deletions
+825
-0
viz/drag_widget.py
viz/drag_widget.py
+168
-0
viz/latent_widget.py
viz/latent_widget.py
+96
-0
viz/pickle_widget.py
viz/pickle_widget.py
+173
-0
viz/renderer.py
viz/renderer.py
+388
-0
No files found.
viz/drag_widget.py
0 → 100644
View file @
fba8bde8
import
os
import
torch
import
numpy
as
np
import
imgui
import
dnnlib
from
gui_utils
import
imgui_utils
#----------------------------------------------------------------------------
class DragWidget:
    """GUI panel that collects drag handle/target points and the editing mask.

    The widget is driven once per frame through ``__call__``, which draws the
    imgui controls (when ``show`` is true) and then publishes its state onto
    ``viz.args``.  Mouse events are fed in through ``action``.
    """

    def __init__(self, viz):
        self.viz = viz
        self.point = [-1, -1]        # last pressed position, stored as [y, x]
        self.points = []             # committed handle points ([y, x] each)
        self.targets = []            # committed target points ([y, x] each)
        self.is_point = True         # True: next commit is a handle, False: a target
        self.last_click = False      # value of ``click`` on the previous add_point() call
        self.is_drag = False
        self.iteration = 0
        self.mode = 'point'          # 'point', 'flexible' or 'fixed'
        self.r_mask = 50             # brush radius used by draw_mask()
        self.show_mask = False
        self.mask = torch.ones(256, 256)
        self.lambda_mask = 20
        self.feature_idx = 5
        self.r1 = 3
        self.r2 = 12
        self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots'))
        self.defer_frames = 0
        self.disabled_time = 0

    def action(self, click, down, x, y):
        """Dispatch a mouse event to point placement or mask painting, depending on mode."""
        if self.mode == 'point':
            self.add_point(click, x, y)
        elif down:
            self.draw_mask(x, y)

    def add_point(self, click, x, y):
        """Track the pressed position and commit it once ``click`` goes back to false.

        While ``click`` is true the position is only remembered.  On the first
        call after that with ``click`` false, the remembered position is
        appended alternately to ``points`` and ``targets``; a running drag is
        stopped first.
        """
        if click:
            self.point = [y, x]
        elif self.last_click:
            if self.is_drag:
                self.stop_drag()
            if self.is_point:
                self.points.append(self.point)
                self.is_point = False
            else:
                self.targets.append(self.point)
                self.is_point = True
        self.last_click = click

    def init_mask(self, w, h):
        """Record the image size and reset the mask to all ones with shape (h, w)."""
        self.width, self.height = w, h
        self.mask = torch.ones(h, w)

    def _mask_size(self):
        """Return (width, height), using the mask's own shape if init_mask() has not run yet."""
        mask_h, mask_w = self.mask.shape[-2:]
        return getattr(self, 'width', mask_w), getattr(self, 'height', mask_h)

    def draw_mask(self, x, y):
        """Paint a disc of radius ``r_mask`` centred on (x, y) into the mask.

        In 'flexible' mode the disc is set to 0, in 'fixed' mode to 1; any
        other mode leaves the mask untouched.
        """
        width, height = self._mask_size()
        X = torch.linspace(0, width, width)
        Y = torch.linspace(0, height, height)
        # 'ij' is the historical default; passing it explicitly avoids the
        # deprecation warning emitted by recent torch versions.
        yy, xx = torch.meshgrid(Y, X, indexing='ij')
        circle = (xx - x) ** 2 + (yy - y) ** 2 < self.r_mask ** 2
        if self.mode == 'flexible':
            self.mask[circle] = 0
        elif self.mode == 'fixed':
            self.mask[circle] = 1

    def stop_drag(self):
        """Stop dragging and reset the step counter."""
        self.is_drag = False
        self.iteration = 0

    def set_points(self, points):
        """Replace the list of handle points."""
        self.points = points

    def reset_point(self):
        """Discard all handle and target points; the next commit is a handle again."""
        self.points = []
        self.targets = []
        self.is_point = True

    def load_points(self, suffix):
        """Read ``[y, x]`` integer pairs from ``<self.path>_<suffix>.txt``.

        Each non-blank line must hold two whitespace-separated integers.  On an
        unreadable file or a malformed line a message is printed and the points
        parsed so far are returned.
        """
        points = []
        point_path = self.path + f'_{suffix}.txt'
        try:
            with open(point_path, "r") as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank/trailing lines
                    y, x = line.split()
                    points.append([int(y), int(x)])
        except (OSError, ValueError):
            # OSError: missing/unreadable file; ValueError: malformed line.
            print(f'Wrong point file path: {point_path}')
        return points

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the panel (if ``show``) and publish the widget state to ``viz.args``."""
        viz = self.viz
        reset = False
        if show:
            with imgui_utils.grayed_out(self.disabled_time != 0):
                has_image = 'image' in viz.result

                imgui.text('Drag')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Add point', width=viz.button_w, enabled=has_image):
                    self.mode = 'point'

                imgui.same_line()
                if imgui_utils.button('Reset point', width=viz.button_w, enabled=has_image):
                    self.reset_point()
                    reset = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Start', width=viz.button_w, enabled=has_image):
                    self.is_drag = True
                    # Drop a trailing handle that has no target yet.
                    if len(self.points) > len(self.targets):
                        self.points = self.points[:len(self.targets)]

                imgui.same_line()
                if imgui_utils.button('Stop', width=viz.button_w, enabled=has_image):
                    self.stop_drag()

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                imgui.text(f'Steps: {self.iteration}')

                imgui.text('Mask')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Flexible area', width=viz.button_w, enabled=has_image):
                    self.mode = 'flexible'
                    self.show_mask = True

                imgui.same_line()
                if imgui_utils.button('Fixed area', width=viz.button_w, enabled=has_image):
                    self.mode = 'fixed'
                    self.show_mask = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Reset mask', width=viz.button_w, enabled=has_image):
                    width, height = self._mask_size()
                    self.mask = torch.ones(height, width)
                imgui.same_line()
                _clicked, self.show_mask = imgui.checkbox('Show mask', self.show_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.r_mask = imgui.input_int('Radius', self.r_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.lambda_mask = imgui.input_int('Lambda', self.lambda_mask)

        self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
        if self.defer_frames > 0:
            self.defer_frames -= 1
        viz.args.is_drag = self.is_drag
        if self.is_drag:
            self.iteration += 1
        viz.args.iteration = self.iteration
        # Shallow copies, so the consumer gets its own list objects.
        viz.args.points = list(self.points)
        viz.args.targets = list(self.targets)
        viz.args.mask = self.mask
        viz.args.lambda_mask = self.lambda_mask
        viz.args.feature_idx = self.feature_idx
        viz.args.r1 = self.r1
        viz.args.r2 = self.r2
        viz.args.reset = reset
#----------------------------------------------------------------------------
viz/latent_widget.py
0 → 100644
View file @
fba8bde8
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import
os
import
numpy
as
np
import
imgui
import
dnnlib
import
torch
from
gui_utils
import
imgui_utils
#----------------------------------------------------------------------------
class LatentWidget:
    """GUI panel for choosing the latent code: seed, saved latent file, step size, w / w+."""

    def __init__(self, viz):
        self.viz = viz
        self.seed = 0
        self.w_plus = True
        self.reg = 0
        self.lr = 0.001
        self.w_path = ''
        self.w_load = None          # latent loaded from ``w_path``; None means "use the seed"
        self.defer_frames = 0
        self.disabled_time = 0

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the panel (if ``show``) and publish the latent settings to ``viz.args``."""
        viz = self.viz
        # Defaults for the values published below.  Without them a call with
        # show=False failed with UnboundLocalError, because both names were
        # only bound inside the ``if show`` branch.
        reset_w = False
        lr = self.lr
        if show:
            with imgui_utils.grayed_out(self.disabled_time != 0):
                imgui.text('Latent')
                imgui.same_line(viz.label_w)

                with imgui_utils.item_width(viz.font_size * 8.75):
                    changed, seed = imgui.input_int('Seed', self.seed)
                    if changed:
                        self.seed = seed
                        # reset latent code
                        self.w_load = None

                # load latent code
                imgui.text(' ')
                imgui.same_line(viz.label_w)
                _changed, self.w_path = imgui_utils.input_text(
                    '##path', self.w_path, 1024,
                    flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE),
                    width=(-1),
                    help_text='Path to latent code')
                if imgui.is_item_hovered() and not imgui.is_item_active() and self.w_path != '':
                    imgui.set_tooltip(self.w_path)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Load latent', width=viz.button_w,
                                      enabled=(self.disabled_time == 0 and 'image' in viz.result)):
                    assert os.path.isfile(self.w_path), f"{self.w_path} does not exist!"
                    # NOTE(review): torch.load unpickles the file, so only
                    # latent files from a trusted source should be opened here.
                    self.w_load = torch.load(self.w_path)
                    self.defer_frames = 2
                    self.disabled_time = 0.5

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.button_w):
                    changed, lr = imgui.input_float('Step Size', self.lr)
                    if changed:
                        self.lr = lr

                # imgui.text(' ')
                # imgui.same_line(viz.label_w)
                # with imgui_utils.item_width(viz.button_w):
                #     changed, reg = imgui.input_float('Regularize', self.reg)
                #     if changed:
                #         self.reg = reg

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                reset_w = imgui_utils.button('Reset', width=viz.button_w, enabled='image' in viz.result)
                imgui.same_line()
                # The 'w' and 'w+' checkboxes act as a pair: ticking 'w' clears
                # w_plus before the 'w+' checkbox is drawn from it.
                _clicked, w = imgui.checkbox('w', not self.w_plus)
                if w:
                    self.w_plus = False
                imgui.same_line()
                _clicked, self.w_plus = imgui.checkbox('w+', self.w_plus)

        self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
        if self.defer_frames > 0:
            self.defer_frames -= 1
        viz.args.w0_seed = self.seed
        viz.args.w_load = self.w_load
        viz.args.reg = self.reg
        viz.args.w_plus = self.w_plus
        viz.args.reset_w = reset_w
        viz.args.lr = lr
#----------------------------------------------------------------------------
viz/pickle_widget.py
0 → 100644
View file @
fba8bde8
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import
glob
import
os
import
re
import
dnnlib
import
imgui
import
numpy
as
np
from
gui_utils
import
imgui_utils
from
.
import
renderer
#----------------------------------------------------------------------------
def _locate_results(pattern):
    """Return ``pattern`` unchanged.

    Placeholder hook used by ``PickleWidget.resolve_pkl`` for short-hand
    patterns; in this file it performs no lookup.
    """
    return pattern
#----------------------------------------------------------------------------
class PickleWidget:
    """GUI panel for selecting the network pickle (text field, recent list, browser, drag-and-drop)."""

    def __init__(self, viz):
        self.viz = viz
        self.search_dirs = []
        self.cur_pkl = None         # resolved pickle currently selected, or None
        self.user_pkl = ''          # text shown in the input field
        self.recent_pkls = []
        self.browse_cache = dict()  # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...}
        self.browse_refocus = False
        self.load('', ignore_errors=True)

    def add_recent(self, pkl, ignore_errors=False):
        """Resolve ``pkl`` and append it to the recent list if it is not already there."""
        try:
            resolved = self.resolve_pkl(pkl)
            if resolved not in self.recent_pkls:
                self.recent_pkls.append(resolved)
        except Exception:
            # Exception (not a bare except) so that KeyboardInterrupt and
            # SystemExit are never swallowed by ignore_errors=True.
            if not ignore_errors:
                raise

    def load(self, pkl, ignore_errors=False):
        """Select ``pkl`` as the current network pickle.

        On success the resolved path becomes ``cur_pkl`` and moves to the front
        of the recent list.  On failure ``cur_pkl`` is cleared and ``viz.result``
        is replaced by a message (empty input) or the captured error; the
        exception is re-raised unless ``ignore_errors`` is true.
        """
        viz = self.viz
        viz.clear_result()
        viz.skip_frame()  # The input field will change on next frame.
        try:
            resolved = self.resolve_pkl(pkl)
            name = resolved.replace('\\', '/').split('/')[-1]
            self.cur_pkl = resolved
            self.user_pkl = resolved
            viz.result.message = f'Loading {name}...'
            viz.defer_rendering()
            if resolved in self.recent_pkls:
                self.recent_pkls.remove(resolved)
            self.recent_pkls.insert(0, resolved)
        except Exception:
            self.cur_pkl = None
            self.user_pkl = pkl
            if pkl == '':
                viz.result = dnnlib.EasyDict(message='No network pickle loaded')
            else:
                viz.result = dnnlib.EasyDict(error=renderer.CapturedException())
            if not ignore_errors:
                raise

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the panel (if ``show``), handle popups and dropped files, publish ``viz.args.pkl``."""
        viz = self.viz
        recent_pkls = [pkl for pkl in self.recent_pkls if pkl != self.user_pkl]
        if show:
            imgui.text('Pickle')
            imgui.same_line(viz.label_w)
            # Only the part after the last '/' is shown in the field; the
            # field's value is written back to user_pkl.
            idx = self.user_pkl.rfind('/')
            changed, self.user_pkl = imgui_utils.input_text(
                '##pkl', self.user_pkl[idx + 1:], 1024,
                flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE),
                width=(-1),
                help_text='<PATH> | <URL> | <RUN_DIR> | <RUN_ID> | <RUN_ID>/<KIMG>.pkl')
            if changed:
                self.load(self.user_pkl, ignore_errors=True)
            if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pkl != '':
                imgui.set_tooltip(self.user_pkl)
            # imgui.same_line()
            imgui.text(' ')
            imgui.same_line(viz.label_w)
            if imgui_utils.button('Recent...', width=viz.button_w, enabled=(len(recent_pkls) != 0)):
                imgui.open_popup('recent_pkls_popup')
            imgui.same_line()
            if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=viz.button_w):
                imgui.open_popup('browse_pkls_popup')
                self.browse_cache.clear()
                self.browse_refocus = True

        if imgui.begin_popup('recent_pkls_popup'):
            for pkl in recent_pkls:
                clicked, _state = imgui.menu_item(pkl)
                if clicked:
                    self.load(pkl, ignore_errors=True)
            imgui.end_popup()

        if imgui.begin_popup('browse_pkls_popup'):
            def recurse(parents):
                # Directory listings are cached per parent tuple until the
                # popup is reopened via the 'Browse...' button.
                key = tuple(parents)
                items = self.browse_cache.get(key, None)
                if items is None:
                    items = self.list_runs_and_pkls(parents)
                    self.browse_cache[key] = items
                for item in items:
                    if item.type == 'run' and imgui.begin_menu(item.name):
                        recurse([item.path])
                        imgui.end_menu()
                    if item.type == 'pkl':
                        clicked, _state = imgui.menu_item(item.name)
                        if clicked:
                            self.load(item.path, ignore_errors=True)
                if len(items) == 0:
                    with imgui_utils.grayed_out():
                        imgui.menu_item('No results found')

            recurse(self.search_dirs)
            if self.browse_refocus:
                imgui.set_scroll_here()
                viz.skip_frame()  # Focus will change on next frame.
                self.browse_refocus = False
            imgui.end_popup()

        paths = viz.pop_drag_and_drop_paths()
        if paths is not None and len(paths) >= 1:
            self.load(paths[0], ignore_errors=True)

        viz.args.pkl = self.cur_pkl

    def list_runs_and_pkls(self, parents):
        """List run directories (``<digits>-*``) and ``network-snapshot-<digits>.pkl`` files.

        Returns a list of ``dnnlib.EasyDict(type, name, path)`` entries found
        directly under the existing directories in ``parents``, sorted by name
        (underscores compared as spaces) and then by path.
        """
        items = []
        run_regex = re.compile(r'\d+-.*')
        pkl_regex = re.compile(r'network-snapshot-\d+\.pkl')
        for parent in set(parents):
            if os.path.isdir(parent):
                for entry in os.scandir(parent):
                    if entry.is_dir() and run_regex.fullmatch(entry.name):
                        items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name)))
                    if entry.is_file() and pkl_regex.fullmatch(entry.name):
                        items.append(dnnlib.EasyDict(type='pkl', name=entry.name, path=os.path.join(parent, entry.name)))

        items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path))
        return items

    def resolve_pkl(self, pattern):
        """Turn a user-supplied pattern into a pickle location.

        URLs are returned as-is.  A directory resolves to its last
        ``network-snapshot-*.pkl`` in sorted order.  Anything else is returned
        as an absolute path.

        Raises:
            AssertionError: ``pattern`` is not a string or is empty.
            IOError: ``pattern`` is a directory without any network snapshot.
        """
        # Raised explicitly rather than via ``assert`` so the validation also
        # runs under ``python -O``; load('') relies on the empty string being
        # rejected here.  The exception type is unchanged.
        if not isinstance(pattern, str):
            raise AssertionError('pattern must be a string')
        if pattern == '':
            raise AssertionError('pattern must not be empty')

        # URL => return as is.
        if dnnlib.util.is_url(pattern):
            return pattern

        # Short-hand pattern => locate.
        path = _locate_results(pattern)

        # Run dir => pick the last saved snapshot.
        if os.path.isdir(path):
            pkl_files = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl')))
            if len(pkl_files) == 0:
                raise IOError(f'No network pickle found in "{path}"')
            path = pkl_files[-1]

        # Normalize.
        path = os.path.abspath(path)
        return path
#----------------------------------------------------------------------------
viz/renderer.py
0 → 100644
View file @
fba8bde8
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from
socket
import
has_dualstack_ipv6
import
sys
import
copy
import
traceback
import
math
import
numpy
as
np
from
PIL
import
Image
,
ImageDraw
,
ImageFont
import
torch
import
torch.fft
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
matplotlib.cm
import
dnnlib
from
torch_utils.ops
import
upfirdn2d
import
legacy
# pylint: disable=import-error
#----------------------------------------------------------------------------
class CapturedException(Exception):
    """Exception whose message is a snapshot of the exception currently being handled.

    With an explicit ``msg`` it behaves like a plain exception.  Without one it
    must be constructed while another exception is being handled: the message
    is then that exception's text if it is itself a ``CapturedException``, and
    the formatted traceback otherwise.
    """

    def __init__(self, msg=None):
        if msg is None:
            active = sys.exc_info()[1]
            assert active is not None
            msg = str(active) if isinstance(active, CapturedException) else traceback.format_exc()
        assert isinstance(msg, str)
        super().__init__(msg)
#----------------------------------------------------------------------------
class CaptureSuccess(Exception):
    """Exception that carries a payload in its ``out`` attribute and has no message."""

    def __init__(self, out):
        # The payload is kept out of ``args`` on purpose: the base class is
        # initialised without arguments.
        self.out = out
        Exception.__init__(self)
#----------------------------------------------------------------------------
def add_watermark_np(input_image_array, watermark_text="AI Generated"):
    """Stamp semi-transparent white text into the bottom-right corner of an image.

    Args:
        input_image_array: image data accepted by ``Image.fromarray`` after
            conversion with ``np.uint8``.
        watermark_text: text to draw.

    Returns:
        The watermarked image as a numpy array converted from an RGBA image.
    """
    image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")

    # Fully transparent overlay of the same size that receives the text.
    txt = Image.new('RGBA', image.size, (255, 255, 255, 0))

    # Font size scales with the image width (25 px on a 512 px wide image).
    font_size = round(25 / 512 * image.size[0])
    try:
        font = ImageFont.truetype('arial.ttf', font_size)
    except OSError:
        # Arial is not installed everywhere; fall back to Pillow's built-in
        # font instead of failing the whole render.
        font = ImageFont.load_default()
    d = ImageDraw.Draw(txt)

    # ImageFont.getsize() was removed in Pillow 10.  The right/bottom edges of
    # getbbox() are the same extents getsize() used to report, so the text
    # lands in the same place.
    if hasattr(font, 'getbbox'):
        _left, _top, text_width, text_height = font.getbbox(watermark_text)
    else:
        text_width, text_height = font.getsize(watermark_text)
    text_position = (image.size[0] - text_width - 10, image.size[1] - text_height - 10)
    text_color = (255, 255, 255, 128)  # white, alpha channel set to semi-transparent

    # Draw the text onto the overlay, then blend the overlay over the image.
    d.text(text_position, watermark_text, font=font, fill=text_color)
    watermarked = Image.alpha_composite(image, txt)
    watermarked_array = np.array(watermarked)
    return watermarked_array
#----------------------------------------------------------------------------
class Renderer:
    """Loads a generator, keeps the optimised latent code and renders one drag step per call.

    ``render(**args)`` is the entry point.  It (re)initialises the network
    state when the relevant arguments changed, runs ``_render_drag_impl`` and
    returns a ``dnnlib.EasyDict`` with the image and bookkeeping fields, or an
    ``error`` string when something failed.
    """

    def __init__(self, disable_timing=False):
        self._device = torch.device(
            'cuda' if torch.cuda.is_available()
            else 'mps' if torch.backends.mps.is_available()
            else 'cpu')
        self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
        self._pkl_data = dict()     # {pkl: dict | CapturedException, ...}
        self._networks = dict()     # {cache_key: torch.nn.Module, ...}
        self._pinned_bufs = dict()  # {(shape, dtype): torch.Tensor, ...}
        self._cmaps = dict()        # {name: torch.Tensor, ...}
        self._is_timing = False
        # Timing relies on CUDA events recorded on a CUDA stream, so it is
        # switched off on any other device instead of failing in render().
        self._disable_timing = disable_timing or self._device.type != 'cuda'
        if not self._disable_timing:
            self._start_event = torch.cuda.Event(enable_timing=True)
            self._end_event = torch.cuda.Event(enable_timing=True)
        self._net_layers = dict()   # {cache_key: [dnnlib.EasyDict, ...], ...}

    def _needs_init(self, args):
        """Return True when the network/latent state must be rebuilt for ``args``."""
        init_net = False
        if not hasattr(self, 'G'):
            init_net = True
        if hasattr(self, 'pkl'):
            if self.pkl != args['pkl']:
                init_net = True
        if hasattr(self, 'w_load'):
            # Identity comparison: a newly loaded latent is a new object.
            if self.w_load is not args['w_load']:
                init_net = True
        if hasattr(self, 'w0_seed'):
            if self.w0_seed != args['w0_seed']:
                init_net = True
        if hasattr(self, 'w_plus'):
            if self.w_plus != args['w_plus']:
                init_net = True
        if args['reset_w']:
            init_net = True
        return init_net

    def render(self, **args):
        """Run one render step and return the result dictionary.

        Exceptions raised while initialising or rendering are captured and
        reported as ``res.error`` (a string).  ``res.render_time`` is set, in
        seconds, only when timing is active for this call.
        """
        if self._disable_timing:
            self._is_timing = False
        else:
            self._start_event.record(torch.cuda.current_stream(self._device))
            self._is_timing = True
        res = dnnlib.EasyDict()
        try:
            init_net = self._needs_init(args)
            res.init_net = init_net
            if init_net:
                self.init_network(res, **args)
            self._render_drag_impl(res, **args)
        except Exception:
            # Exception (not a bare except) so Ctrl-C / SystemExit propagate.
            res.error = CapturedException()
        if not self._disable_timing:
            self._end_event.record(torch.cuda.current_stream(self._device))
        if 'image' in res:
            res.image = self.to_cpu(res.image).detach().numpy()
            res.image = add_watermark_np(res.image, 'AI Generated')
        if 'stats' in res:
            res.stats = self.to_cpu(res.stats).detach().numpy()
        if 'error' in res:
            res.error = str(res.error)
        # if 'stop' in res and res.stop:

        if self._is_timing and not self._disable_timing:
            self._end_event.synchronize()
            res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3
            self._is_timing = False
        return res

    def get_network(self, pkl, key, **tweak_kwargs):
        """Return the generator stored under ``key`` in ``pkl``, rebuilt on this device.

        Both the unpickled data and the constructed network are cached; a
        failure is cached as a ``CapturedException`` and raised again on every
        later call with the same arguments.  The generator class is chosen
        from the substring 'stylegan2', 'stylegan3' or 'stylegan_human' in
        ``pkl``.
        """
        data = self._pkl_data.get(pkl, None)
        if data is None:
            print(f'Loading "{pkl}"... ', end='', flush=True)
            try:
                with dnnlib.util.open_url(pkl, verbose=False) as f:
                    data = legacy.load_network_pkl(f)
                print('Done.')
            except Exception:
                data = CapturedException()
                print('Failed!')
            self._pkl_data[pkl] = data
            self._ignore_timing()
        if isinstance(data, CapturedException):
            raise data

        orig_net = data[key]
        cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
        net = self._networks.get(cache_key, None)
        if net is None:
            try:
                if 'stylegan2' in pkl:
                    from training.networks_stylegan2 import Generator
                elif 'stylegan3' in pkl:
                    from training.networks_stylegan3 import Generator
                elif 'stylegan_human' in pkl:
                    from stylegan_human.training_scripts.sg2.training.networks import Generator
                else:
                    raise NameError('Cannot infer model type from pkl name!')

                print(orig_net.init_args)
                print(orig_net.init_kwargs)
                if 'stylegan_human' in pkl:
                    net = Generator(*orig_net.init_args, **orig_net.init_kwargs, square=False, padding=True)
                else:
                    net = Generator(*orig_net.init_args, **orig_net.init_kwargs)
                net.load_state_dict(orig_net.state_dict())
                net.to(self._device)
            except Exception:
                net = CapturedException()
            self._networks[cache_key] = net
            self._ignore_timing()
        if isinstance(net, CapturedException):
            raise net
        return net

    def _get_pinned_buf(self, ref):
        """Return a cached CPU staging buffer with the shape and dtype of ``ref``."""
        key = (tuple(ref.shape), ref.dtype)
        buf = self._pinned_bufs.get(key, None)
        if buf is None:
            buf = torch.empty(ref.shape, dtype=ref.dtype)
            # Pinning needs a CUDA runtime; on other devices a plain CPU
            # buffer is used so transfers still work.
            if self._device.type == 'cuda':
                buf = buf.pin_memory()
            self._pinned_bufs[key] = buf
        return buf

    def to_device(self, buf):
        """Copy ``buf`` through the staging buffer onto the render device."""
        return self._get_pinned_buf(buf).copy_(buf).to(self._device)

    def to_cpu(self, buf):
        """Copy ``buf`` through the staging buffer and return an independent CPU tensor."""
        return self._get_pinned_buf(buf).copy_(buf).clone()

    def _ignore_timing(self):
        """Mark the current render as untimed (used after one-off loading work)."""
        self._is_timing = False

    def _apply_cmap(self, x, name='viridis'):
        """Map values of ``x`` through a 1024-entry RGB lookup table of the named colormap.

        ``x`` is scaled by the table size, rounded and clamped to valid
        indices before the lookup.
        """
        cmap = self._cmaps.get(name, None)
        if cmap is None:
            # matplotlib.cm.get_cmap was removed in matplotlib 3.9; prefer the
            # colormap registry and only fall back on old versions.
            registry = getattr(matplotlib, 'colormaps', None)
            cmap = registry[name] if registry is not None else matplotlib.cm.get_cmap(name)
            cmap = cmap(np.linspace(0, 1, num=1024), bytes=True)[:, :3]
            cmap = self.to_device(torch.from_numpy(cmap))
            self._cmaps[name] = cmap
        hi = cmap.shape[0] - 1
        x = (x * hi + 0.5).clamp(0, hi).to(torch.int64)
        x = torch.nn.functional.embedding(x, cmap)
        return x

    def init_network(self, res,
        pkl             = None,
        w0_seed         = 0,
        w_load          = None,
        w_plus          = True,
        noise_mode      = 'const',
        trunc_psi       = 0.7,
        trunc_cutoff    = None,
        input_transform = None,
        lr              = 0.001,
        **kwargs
    ):
        """Load the generator and (re)create the latent code and its optimiser.

        The latent comes from ``w_load`` when given, otherwise from the mapping
        network applied to a seeded random ``z``.  Network details are written
        to ``res``; drag reference state (``feat_refs``, ``points0_pt``) is
        cleared.
        """
        # Dig up network details.
        self.pkl = pkl
        G = self.get_network(pkl, 'G_ema')
        self.G = G
        res.img_resolution = G.img_resolution
        res.num_ws = G.num_ws
        res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
        res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))

        # Set input transform.
        if res.has_input_transform:
            m = np.eye(3)
            try:
                if input_transform is not None:
                    m = np.linalg.inv(np.asarray(input_transform))
            except np.linalg.LinAlgError:
                # Singular transform: report it and keep the identity.
                res.error = CapturedException()
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        self.w0_seed = w0_seed
        self.w_load = w_load

        if self.w_load is None:
            # Generate random latents.
            z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device, dtype=self._dtype)

            # Run mapping network.
            label = torch.zeros([1, G.c_dim], device=self._device)
            w = G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff)
        else:
            w = self.w_load.clone().to(self._device)

        self.w0 = w.detach().clone()    # reference latent, never optimised
        self.w_plus = w_plus
        if w_plus:
            self.w = w.detach()
        else:
            self.w = w[:, 0, :].detach()
        self.w.requires_grad = True
        self.w_optim = torch.optim.Adam([self.w], lr=lr)

        self.feat_refs = None
        self.points0_pt = None

    def update_lr(self, lr):
        """Replace the optimiser with a fresh Adam using ``lr``; drag references are kept."""
        del self.w_optim
        self.w_optim = torch.optim.Adam([self.w], lr=lr)
        print(f'Rebuild optimizer with lr: {lr}')
        print(' Remain feat_refs and points0_pt')

    def _render_drag_impl(self, res,
        points          = None,
        targets         = None,
        mask            = None,
        lambda_mask     = 10,
        reg             = 0,
        feature_idx     = 5,
        r1              = 3,
        r2              = 12,
        random_seed     = 0,
        noise_mode      = 'const',
        trunc_psi       = 0.7,
        force_fp32      = False,
        layer_name      = None,
        sel_channels    = 3,
        base_channel    = 0,
        img_scale_db    = 0,
        img_normalize   = False,
        untransform     = False,
        is_drag         = False,
        reset           = False,
        to_pil          = False,
        **kwargs
    ):
        """Synthesise the image and, when ``is_drag`` is set, run one drag step.

        A drag step re-locates every handle point by feature matching, then
        takes one optimiser step on the motion-supervision loss (plus the
        optional mask and latent regularisation terms) unless every handle is
        already close to its target.  ``points`` is updated in place with the
        tracked positions; results are written to ``res`` (``image``, ``w``
        and, while dragging, ``points`` and ``stop``).
        """
        # None instead of mutable default arguments: the handle list is
        # mutated and stored on self below.
        if points is None:
            points = []
        if targets is None:
            targets = []

        G = self.G
        ws = self.w
        if ws.dim() == 2:
            ws = ws.unsqueeze(1).repeat(1, 6, 1)
        # Only the first 6 layers come from the optimised latent; the rest are
        # taken from the reference latent w0.
        ws = torch.cat([ws[:, :6, :], self.w0[:, 6:, :]], dim=1)
        if hasattr(self, 'points'):
            if len(points) != len(self.points):
                reset = True
        if reset:
            self.feat_refs = None
            self.points0_pt = None
        self.points = points

        # Run synthesis network.
        label = torch.zeros([1, G.c_dim], device=self._device)
        img, feat = G(ws, label, truncation_psi=trunc_psi, noise_mode=noise_mode, input_is_w=True, return_feature=True)

        h, w = G.img_resolution, G.img_resolution

        if is_drag:
            X = torch.linspace(0, h, h)
            Y = torch.linspace(0, w, w)
            # 'ij' is the historical default; explicit to avoid the
            # deprecation warning of recent torch versions.
            xx, yy = torch.meshgrid(X, Y, indexing='ij')
            xx, yy = xx.to(self._device), yy.to(self._device)
            feat_resize = F.interpolate(feat[feature_idx], [h, w], mode='bilinear')
            if self.feat_refs is None:
                # First drag step: remember the initial features at each handle.
                self.feat0_resize = F.interpolate(feat[feature_idx].detach(), [h, w], mode='bilinear')
                self.feat_refs = []
                for point in points:
                    py, px = round(point[0]), round(point[1])
                    self.feat_refs.append(self.feat0_resize[:, :, py, px])
                self.points0_pt = torch.Tensor(points).unsqueeze(0).to(self._device)  # 1, N, 2

            # Point tracking with feature matching
            with torch.no_grad():
                r = round(r2 / 512 * h)  # search radius, given relative to a 512 px image
                for j, point in enumerate(points):
                    up = max(point[0] - r, 0)
                    down = min(point[0] + r + 1, h)
                    left = max(point[1] - r, 0)
                    right = min(point[1] + r + 1, w)
                    feat_patch = feat_resize[:, :, up:down, left:right]
                    L2 = torch.linalg.norm(feat_patch - self.feat_refs[j].reshape(1, -1, 1, 1), dim=1)
                    _, idx = torch.min(L2.view(1, -1), -1)
                    width = right - left
                    # Convert the flat patch index back to image coordinates.
                    point = [idx.item() // width + up, idx.item() % width + left]
                    points[j] = point

            res.points = [[point[0], point[1]] for point in points]

            # Motion supervision
            loss_motion = 0
            res.stop = True
            for j, point in enumerate(points):
                direction = torch.Tensor([targets[j][1] - point[1], targets[j][0] - point[0]])
                if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
                    res.stop = False
                if torch.linalg.norm(direction) > 1:
                    distance = ((xx - point[0]) ** 2 + (yy - point[1]) ** 2) ** 0.5
                    relis, reljs = torch.where(distance < round(r1 / 512 * h))
                    direction = direction / (torch.linalg.norm(direction) + 1e-7)
                    # Sampling grid shifted one unit step towards the target,
                    # normalised to [-1, 1] for grid_sample.
                    gridh = (relis + direction[1]) / (h - 1) * 2 - 1
                    gridw = (reljs + direction[0]) / (w - 1) * 2 - 1
                    grid = torch.stack([gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
                    target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2)
                    loss_motion += F.l1_loss(feat_resize[:, :, relis, reljs].detach(), target)

            loss = loss_motion
            if mask is not None:
                # The mask term is only applied when the mask holds both 0 and 1.
                if mask.min() == 0 and mask.max() == 1:
                    mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
                    loss_fix = F.l1_loss(feat_resize * mask_usq, self.feat0_resize * mask_usq)
                    loss += lambda_mask * loss_fix

            loss += reg * F.l1_loss(ws, self.w0)  # latent code regularization
            if not res.stop:
                self.w_optim.zero_grad()
                loss.backward()
                self.w_optim.step()

        # Scale and convert to uint8.
        img = img[0]
        if img_normalize:
            img = img / img.norm(float('inf'), dim=[1, 2], keepdim=True).clip(1e-8, 1e8)
        img = img * (10 ** (img_scale_db / 20))
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
        if to_pil:
            from PIL import Image
            img = img.cpu().numpy()
            img = Image.fromarray(img)
        res.image = img
        res.w = ws.detach().cpu().numpy()
#----------------------------------------------------------------------------
Prev
1
…
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment