Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
stylegan3_pytorch
Commits
310493b2
Commit
310493b2
authored
Jan 05, 2024
by
mashun1
Browse files
stylegan3
parents
Pipeline
#695
canceled with stages
Changes
103
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
518 additions
and
0 deletions
+518
-0
viz/renderer.py
viz/renderer.py
+377
-0
viz/stylemix_widget.py
viz/stylemix_widget.py
+66
-0
viz/trunc_noise_widget.py
viz/trunc_noise_widget.py
+75
-0
No files found.
viz/renderer.py
0 → 100644
View file @
310493b2
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import
sys
import
copy
import
traceback
import
numpy
as
np
import
torch
import
torch.fft
import
torch.nn
import
matplotlib.cm
import
dnnlib
from
torch_utils.ops
import
upfirdn2d
import
legacy
# pylint: disable=import-error
#----------------------------------------------------------------------------
class
CapturedException
(
Exception
):
def
__init__
(
self
,
msg
=
None
):
if
msg
is
None
:
_type
,
value
,
_traceback
=
sys
.
exc_info
()
assert
value
is
not
None
if
isinstance
(
value
,
CapturedException
):
msg
=
str
(
value
)
else
:
msg
=
traceback
.
format_exc
()
assert
isinstance
(
msg
,
str
)
super
().
__init__
(
msg
)
#----------------------------------------------------------------------------
class
CaptureSuccess
(
Exception
):
def
__init__
(
self
,
out
):
super
().
__init__
()
self
.
out
=
out
#----------------------------------------------------------------------------
def
_sinc
(
x
):
y
=
(
x
*
np
.
pi
).
abs
()
z
=
torch
.
sin
(
y
)
/
y
.
clamp
(
1e-30
,
float
(
'inf'
))
return
torch
.
where
(
y
<
1e-30
,
torch
.
ones_like
(
x
),
z
)
def
_lanczos_window
(
x
,
a
):
x
=
x
.
abs
()
/
a
return
torch
.
where
(
x
<
1
,
_sinc
(
x
),
torch
.
zeros_like
(
x
))
#----------------------------------------------------------------------------
def
_construct_affine_bandlimit_filter
(
mat
,
a
=
3
,
amax
=
16
,
aflt
=
64
,
up
=
4
,
cutoff_in
=
1
,
cutoff_out
=
1
):
assert
a
<=
amax
<
aflt
mat
=
torch
.
as_tensor
(
mat
).
to
(
torch
.
float32
)
# Construct 2D filter taps in input & output coordinate spaces.
taps
=
((
torch
.
arange
(
aflt
*
up
*
2
-
1
,
device
=
mat
.
device
)
+
1
)
/
up
-
aflt
).
roll
(
1
-
aflt
*
up
)
yi
,
xi
=
torch
.
meshgrid
(
taps
,
taps
)
xo
,
yo
=
(
torch
.
stack
([
xi
,
yi
],
dim
=
2
)
@
mat
[:
2
,
:
2
].
t
()).
unbind
(
2
)
# Convolution of two oriented 2D sinc filters.
fi
=
_sinc
(
xi
*
cutoff_in
)
*
_sinc
(
yi
*
cutoff_in
)
fo
=
_sinc
(
xo
*
cutoff_out
)
*
_sinc
(
yo
*
cutoff_out
)
f
=
torch
.
fft
.
ifftn
(
torch
.
fft
.
fftn
(
fi
)
*
torch
.
fft
.
fftn
(
fo
)).
real
# Convolution of two oriented 2D Lanczos windows.
wi
=
_lanczos_window
(
xi
,
a
)
*
_lanczos_window
(
yi
,
a
)
wo
=
_lanczos_window
(
xo
,
a
)
*
_lanczos_window
(
yo
,
a
)
w
=
torch
.
fft
.
ifftn
(
torch
.
fft
.
fftn
(
wi
)
*
torch
.
fft
.
fftn
(
wo
)).
real
# Construct windowed FIR filter.
f
=
f
*
w
# Finalize.
c
=
(
aflt
-
amax
)
*
up
f
=
f
.
roll
([
aflt
*
up
-
1
]
*
2
,
dims
=
[
0
,
1
])[
c
:
-
c
,
c
:
-
c
]
f
=
torch
.
nn
.
functional
.
pad
(
f
,
[
0
,
1
,
0
,
1
]).
reshape
(
amax
*
2
,
up
,
amax
*
2
,
up
)
f
=
f
/
f
.
sum
([
0
,
2
],
keepdim
=
True
)
/
(
up
**
2
)
f
=
f
.
reshape
(
amax
*
2
*
up
,
amax
*
2
*
up
)[:
-
1
,
:
-
1
]
return
f
#----------------------------------------------------------------------------
def
_apply_affine_transformation
(
x
,
mat
,
up
=
4
,
**
filter_kwargs
):
_N
,
_C
,
H
,
W
=
x
.
shape
mat
=
torch
.
as_tensor
(
mat
).
to
(
dtype
=
torch
.
float32
,
device
=
x
.
device
)
# Construct filter.
f
=
_construct_affine_bandlimit_filter
(
mat
,
up
=
up
,
**
filter_kwargs
)
assert
f
.
ndim
==
2
and
f
.
shape
[
0
]
==
f
.
shape
[
1
]
and
f
.
shape
[
0
]
%
2
==
1
p
=
f
.
shape
[
0
]
//
2
# Construct sampling grid.
theta
=
mat
.
inverse
()
theta
[:
2
,
2
]
*=
2
theta
[
0
,
2
]
+=
1
/
up
/
W
theta
[
1
,
2
]
+=
1
/
up
/
H
theta
[
0
,
:]
*=
W
/
(
W
+
p
/
up
*
2
)
theta
[
1
,
:]
*=
H
/
(
H
+
p
/
up
*
2
)
theta
=
theta
[:
2
,
:
3
].
unsqueeze
(
0
).
repeat
([
x
.
shape
[
0
],
1
,
1
])
g
=
torch
.
nn
.
functional
.
affine_grid
(
theta
,
x
.
shape
,
align_corners
=
False
)
# Resample image.
y
=
upfirdn2d
.
upsample2d
(
x
=
x
,
f
=
f
,
up
=
up
,
padding
=
p
)
z
=
torch
.
nn
.
functional
.
grid_sample
(
y
,
g
,
mode
=
'bilinear'
,
padding_mode
=
'zeros'
,
align_corners
=
False
)
# Form mask.
m
=
torch
.
zeros_like
(
y
)
c
=
p
*
2
+
1
m
[:,
:,
c
:
-
c
,
c
:
-
c
]
=
1
m
=
torch
.
nn
.
functional
.
grid_sample
(
m
,
g
,
mode
=
'nearest'
,
padding_mode
=
'zeros'
,
align_corners
=
False
)
return
z
,
m
#----------------------------------------------------------------------------
class
Renderer
:
def
__init__
(
self
):
self
.
_device
=
torch
.
device
(
'cuda'
)
self
.
_pkl_data
=
dict
()
# {pkl: dict | CapturedException, ...}
self
.
_networks
=
dict
()
# {cache_key: torch.nn.Module, ...}
self
.
_pinned_bufs
=
dict
()
# {(shape, dtype): torch.Tensor, ...}
self
.
_cmaps
=
dict
()
# {name: torch.Tensor, ...}
self
.
_is_timing
=
False
self
.
_start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
self
.
_end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
self
.
_net_layers
=
dict
()
# {cache_key: [dnnlib.EasyDict, ...], ...}
def
render
(
self
,
**
args
):
self
.
_is_timing
=
True
self
.
_start_event
.
record
(
torch
.
cuda
.
current_stream
(
self
.
_device
))
res
=
dnnlib
.
EasyDict
()
try
:
self
.
_render_impl
(
res
,
**
args
)
except
:
res
.
error
=
CapturedException
()
self
.
_end_event
.
record
(
torch
.
cuda
.
current_stream
(
self
.
_device
))
if
'image'
in
res
:
res
.
image
=
self
.
to_cpu
(
res
.
image
).
numpy
()
if
'stats'
in
res
:
res
.
stats
=
self
.
to_cpu
(
res
.
stats
).
numpy
()
if
'error'
in
res
:
res
.
error
=
str
(
res
.
error
)
if
self
.
_is_timing
:
self
.
_end_event
.
synchronize
()
res
.
render_time
=
self
.
_start_event
.
elapsed_time
(
self
.
_end_event
)
*
1e-3
self
.
_is_timing
=
False
return
res
def
get_network
(
self
,
pkl
,
key
,
**
tweak_kwargs
):
data
=
self
.
_pkl_data
.
get
(
pkl
,
None
)
if
data
is
None
:
print
(
f
'Loading "
{
pkl
}
"... '
,
end
=
''
,
flush
=
True
)
try
:
with
dnnlib
.
util
.
open_url
(
pkl
,
verbose
=
False
)
as
f
:
data
=
legacy
.
load_network_pkl
(
f
)
print
(
'Done.'
)
except
:
data
=
CapturedException
()
print
(
'Failed!'
)
self
.
_pkl_data
[
pkl
]
=
data
self
.
_ignore_timing
()
if
isinstance
(
data
,
CapturedException
):
raise
data
orig_net
=
data
[
key
]
cache_key
=
(
orig_net
,
self
.
_device
,
tuple
(
sorted
(
tweak_kwargs
.
items
())))
net
=
self
.
_networks
.
get
(
cache_key
,
None
)
if
net
is
None
:
try
:
net
=
copy
.
deepcopy
(
orig_net
)
net
=
self
.
_tweak_network
(
net
,
**
tweak_kwargs
)
net
.
to
(
self
.
_device
)
except
:
net
=
CapturedException
()
self
.
_networks
[
cache_key
]
=
net
self
.
_ignore_timing
()
if
isinstance
(
net
,
CapturedException
):
raise
net
return
net
def
_tweak_network
(
self
,
net
):
# Print diagnostics.
#for name, value in misc.named_params_and_buffers(net):
# if name.endswith('.magnitude_ema'):
# value = value.rsqrt().numpy()
# print(f'{name:<50s}{np.min(value):<16g}{np.max(value):g}')
# if name.endswith('.weight') and value.ndim == 4:
# value = value.square().mean([1,2,3]).sqrt().numpy()
# print(f'{name:<50s}{np.min(value):<16g}{np.max(value):g}')
return
net
def
_get_pinned_buf
(
self
,
ref
):
key
=
(
tuple
(
ref
.
shape
),
ref
.
dtype
)
buf
=
self
.
_pinned_bufs
.
get
(
key
,
None
)
if
buf
is
None
:
buf
=
torch
.
empty
(
ref
.
shape
,
dtype
=
ref
.
dtype
).
pin_memory
()
self
.
_pinned_bufs
[
key
]
=
buf
return
buf
def
to_device
(
self
,
buf
):
return
self
.
_get_pinned_buf
(
buf
).
copy_
(
buf
).
to
(
self
.
_device
)
def
to_cpu
(
self
,
buf
):
return
self
.
_get_pinned_buf
(
buf
).
copy_
(
buf
).
clone
()
def
_ignore_timing
(
self
):
self
.
_is_timing
=
False
def
_apply_cmap
(
self
,
x
,
name
=
'viridis'
):
cmap
=
self
.
_cmaps
.
get
(
name
,
None
)
if
cmap
is
None
:
cmap
=
matplotlib
.
cm
.
get_cmap
(
name
)
cmap
=
cmap
(
np
.
linspace
(
0
,
1
,
num
=
1024
),
bytes
=
True
)[:,
:
3
]
cmap
=
self
.
to_device
(
torch
.
from_numpy
(
cmap
))
self
.
_cmaps
[
name
]
=
cmap
hi
=
cmap
.
shape
[
0
]
-
1
x
=
(
x
*
hi
+
0.5
).
clamp
(
0
,
hi
).
to
(
torch
.
int64
)
x
=
torch
.
nn
.
functional
.
embedding
(
x
,
cmap
)
return
x
def
_render_impl
(
self
,
res
,
pkl
=
None
,
w0_seeds
=
[[
0
,
1
]],
stylemix_idx
=
[],
stylemix_seed
=
0
,
trunc_psi
=
1
,
trunc_cutoff
=
0
,
random_seed
=
0
,
noise_mode
=
'const'
,
force_fp32
=
False
,
layer_name
=
None
,
sel_channels
=
3
,
base_channel
=
0
,
img_scale_db
=
0
,
img_normalize
=
False
,
fft_show
=
False
,
fft_all
=
True
,
fft_range_db
=
50
,
fft_beta
=
8
,
input_transform
=
None
,
untransform
=
False
,
):
# Dig up network details.
G
=
self
.
get_network
(
pkl
,
'G_ema'
)
res
.
img_resolution
=
G
.
img_resolution
res
.
num_ws
=
G
.
num_ws
res
.
has_noise
=
any
(
'noise_const'
in
name
for
name
,
_buf
in
G
.
synthesis
.
named_buffers
())
res
.
has_input_transform
=
(
hasattr
(
G
.
synthesis
,
'input'
)
and
hasattr
(
G
.
synthesis
.
input
,
'transform'
))
# Set input transform.
if
res
.
has_input_transform
:
m
=
np
.
eye
(
3
)
try
:
if
input_transform
is
not
None
:
m
=
np
.
linalg
.
inv
(
np
.
asarray
(
input_transform
))
except
np
.
linalg
.
LinAlgError
:
res
.
error
=
CapturedException
()
G
.
synthesis
.
input
.
transform
.
copy_
(
torch
.
from_numpy
(
m
))
# Generate random latents.
all_seeds
=
[
seed
for
seed
,
_weight
in
w0_seeds
]
+
[
stylemix_seed
]
all_seeds
=
list
(
set
(
all_seeds
))
all_zs
=
np
.
zeros
([
len
(
all_seeds
),
G
.
z_dim
],
dtype
=
np
.
float32
)
all_cs
=
np
.
zeros
([
len
(
all_seeds
),
G
.
c_dim
],
dtype
=
np
.
float32
)
for
idx
,
seed
in
enumerate
(
all_seeds
):
rnd
=
np
.
random
.
RandomState
(
seed
)
all_zs
[
idx
]
=
rnd
.
randn
(
G
.
z_dim
)
if
G
.
c_dim
>
0
:
all_cs
[
idx
,
rnd
.
randint
(
G
.
c_dim
)]
=
1
# Run mapping network.
w_avg
=
G
.
mapping
.
w_avg
all_zs
=
self
.
to_device
(
torch
.
from_numpy
(
all_zs
))
all_cs
=
self
.
to_device
(
torch
.
from_numpy
(
all_cs
))
all_ws
=
G
.
mapping
(
z
=
all_zs
,
c
=
all_cs
,
truncation_psi
=
trunc_psi
,
truncation_cutoff
=
trunc_cutoff
)
-
w_avg
all_ws
=
dict
(
zip
(
all_seeds
,
all_ws
))
# Calculate final W.
w
=
torch
.
stack
([
all_ws
[
seed
]
*
weight
for
seed
,
weight
in
w0_seeds
]).
sum
(
dim
=
0
,
keepdim
=
True
)
stylemix_idx
=
[
idx
for
idx
in
stylemix_idx
if
0
<=
idx
<
G
.
num_ws
]
if
len
(
stylemix_idx
)
>
0
:
w
[:,
stylemix_idx
]
=
all_ws
[
stylemix_seed
][
np
.
newaxis
,
stylemix_idx
]
w
+=
w_avg
# Run synthesis network.
synthesis_kwargs
=
dnnlib
.
EasyDict
(
noise_mode
=
noise_mode
,
force_fp32
=
force_fp32
)
torch
.
manual_seed
(
random_seed
)
out
,
layers
=
self
.
run_synthesis_net
(
G
.
synthesis
,
w
,
capture_layer
=
layer_name
,
**
synthesis_kwargs
)
# Update layer list.
cache_key
=
(
G
.
synthesis
,
tuple
(
sorted
(
synthesis_kwargs
.
items
())))
if
cache_key
not
in
self
.
_net_layers
:
if
layer_name
is
not
None
:
torch
.
manual_seed
(
random_seed
)
_out
,
layers
=
self
.
run_synthesis_net
(
G
.
synthesis
,
w
,
**
synthesis_kwargs
)
self
.
_net_layers
[
cache_key
]
=
layers
res
.
layers
=
self
.
_net_layers
[
cache_key
]
# Untransform.
if
untransform
and
res
.
has_input_transform
:
out
,
_mask
=
_apply_affine_transformation
(
out
.
to
(
torch
.
float32
),
G
.
synthesis
.
input
.
transform
,
amax
=
6
)
# Override amax to hit the fast path in upfirdn2d.
# Select channels and compute statistics.
out
=
out
[
0
].
to
(
torch
.
float32
)
if
sel_channels
>
out
.
shape
[
0
]:
sel_channels
=
1
base_channel
=
max
(
min
(
base_channel
,
out
.
shape
[
0
]
-
sel_channels
),
0
)
sel
=
out
[
base_channel
:
base_channel
+
sel_channels
]
res
.
stats
=
torch
.
stack
([
out
.
mean
(),
sel
.
mean
(),
out
.
std
(),
sel
.
std
(),
out
.
norm
(
float
(
'inf'
)),
sel
.
norm
(
float
(
'inf'
)),
])
# Scale and convert to uint8.
img
=
sel
if
img_normalize
:
img
=
img
/
img
.
norm
(
float
(
'inf'
),
dim
=
[
1
,
2
],
keepdim
=
True
).
clip
(
1e-8
,
1e8
)
img
=
img
*
(
10
**
(
img_scale_db
/
20
))
img
=
(
img
*
127.5
+
128
).
clamp
(
0
,
255
).
to
(
torch
.
uint8
).
permute
(
1
,
2
,
0
)
res
.
image
=
img
# FFT.
if
fft_show
:
sig
=
out
if
fft_all
else
sel
sig
=
sig
.
to
(
torch
.
float32
)
sig
=
sig
-
sig
.
mean
(
dim
=
[
1
,
2
],
keepdim
=
True
)
sig
=
sig
*
torch
.
kaiser_window
(
sig
.
shape
[
1
],
periodic
=
False
,
beta
=
fft_beta
,
device
=
self
.
_device
)[
None
,
:,
None
]
sig
=
sig
*
torch
.
kaiser_window
(
sig
.
shape
[
2
],
periodic
=
False
,
beta
=
fft_beta
,
device
=
self
.
_device
)[
None
,
None
,
:]
fft
=
torch
.
fft
.
fftn
(
sig
,
dim
=
[
1
,
2
]).
abs
().
square
().
sum
(
dim
=
0
)
fft
=
fft
.
roll
(
shifts
=
[
fft
.
shape
[
0
]
//
2
,
fft
.
shape
[
1
]
//
2
],
dims
=
[
0
,
1
])
fft
=
(
fft
/
fft
.
mean
()).
log10
()
*
10
# dB
fft
=
self
.
_apply_cmap
((
fft
/
fft_range_db
+
1
)
/
2
)
res
.
image
=
torch
.
cat
([
img
.
expand_as
(
fft
),
fft
],
dim
=
1
)
@
staticmethod
def
run_synthesis_net
(
net
,
*
args
,
capture_layer
=
None
,
**
kwargs
):
# => out, layers
submodule_names
=
{
mod
:
name
for
name
,
mod
in
net
.
named_modules
()}
unique_names
=
set
()
layers
=
[]
def
module_hook
(
module
,
_inputs
,
outputs
):
outputs
=
list
(
outputs
)
if
isinstance
(
outputs
,
(
tuple
,
list
))
else
[
outputs
]
outputs
=
[
out
for
out
in
outputs
if
isinstance
(
out
,
torch
.
Tensor
)
and
out
.
ndim
in
[
4
,
5
]]
for
idx
,
out
in
enumerate
(
outputs
):
if
out
.
ndim
==
5
:
# G-CNN => remove group dimension.
out
=
out
.
mean
(
2
)
name
=
submodule_names
[
module
]
if
name
==
''
:
name
=
'output'
if
len
(
outputs
)
>
1
:
name
+=
f
':
{
idx
}
'
if
name
in
unique_names
:
suffix
=
2
while
f
'
{
name
}
_
{
suffix
}
'
in
unique_names
:
suffix
+=
1
name
+=
f
'_
{
suffix
}
'
unique_names
.
add
(
name
)
shape
=
[
int
(
x
)
for
x
in
out
.
shape
]
dtype
=
str
(
out
.
dtype
).
split
(
'.'
)[
-
1
]
layers
.
append
(
dnnlib
.
EasyDict
(
name
=
name
,
shape
=
shape
,
dtype
=
dtype
))
if
name
==
capture_layer
:
raise
CaptureSuccess
(
out
)
hooks
=
[
module
.
register_forward_hook
(
module_hook
)
for
module
in
net
.
modules
()]
try
:
out
=
net
(
*
args
,
**
kwargs
)
except
CaptureSuccess
as
e
:
out
=
e
.
out
for
hook
in
hooks
:
hook
.
remove
()
return
out
,
layers
#----------------------------------------------------------------------------
viz/stylemix_widget.py
0 → 100644
View file @
310493b2
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import
imgui
from
gui_utils
import
imgui_utils
#----------------------------------------------------------------------------
class
StyleMixingWidget
:
def
__init__
(
self
,
viz
):
self
.
viz
=
viz
self
.
seed_def
=
1000
self
.
seed
=
self
.
seed_def
self
.
animate
=
False
self
.
enables
=
[]
@
imgui_utils
.
scoped_by_object_id
def
__call__
(
self
,
show
=
True
):
viz
=
self
.
viz
num_ws
=
viz
.
result
.
get
(
'num_ws'
,
0
)
num_enables
=
viz
.
result
.
get
(
'num_ws'
,
18
)
self
.
enables
+=
[
False
]
*
max
(
num_enables
-
len
(
self
.
enables
),
0
)
if
show
:
imgui
.
text
(
'Stylemix'
)
imgui
.
same_line
(
viz
.
label_w
)
with
imgui_utils
.
item_width
(
viz
.
font_size
*
8
),
imgui_utils
.
grayed_out
(
num_ws
==
0
):
_changed
,
self
.
seed
=
imgui
.
input_int
(
'##seed'
,
self
.
seed
)
imgui
.
same_line
(
viz
.
label_w
+
viz
.
font_size
*
8
+
viz
.
spacing
)
with
imgui_utils
.
grayed_out
(
num_ws
==
0
):
_clicked
,
self
.
animate
=
imgui
.
checkbox
(
'Anim'
,
self
.
animate
)
pos2
=
imgui
.
get_content_region_max
()[
0
]
-
1
-
viz
.
button_w
pos1
=
pos2
-
imgui
.
get_text_line_height
()
-
viz
.
spacing
pos0
=
viz
.
label_w
+
viz
.
font_size
*
12
imgui
.
push_style_var
(
imgui
.
STYLE_FRAME_PADDING
,
[
0
,
0
])
for
idx
in
range
(
num_enables
):
imgui
.
same_line
(
round
(
pos0
+
(
pos1
-
pos0
)
*
(
idx
/
(
num_enables
-
1
))))
if
idx
==
0
:
imgui
.
set_cursor_pos_y
(
imgui
.
get_cursor_pos_y
()
+
3
)
with
imgui_utils
.
grayed_out
(
num_ws
==
0
):
_clicked
,
self
.
enables
[
idx
]
=
imgui
.
checkbox
(
f
'##
{
idx
}
'
,
self
.
enables
[
idx
])
if
imgui
.
is_item_hovered
():
imgui
.
set_tooltip
(
f
'
{
idx
}
'
)
imgui
.
pop_style_var
(
1
)
imgui
.
same_line
(
pos2
)
imgui
.
set_cursor_pos_y
(
imgui
.
get_cursor_pos_y
()
-
3
)
with
imgui_utils
.
grayed_out
(
num_ws
==
0
):
if
imgui_utils
.
button
(
'Reset'
,
width
=-
1
,
enabled
=
(
self
.
seed
!=
self
.
seed_def
or
self
.
animate
or
any
(
self
.
enables
[:
num_enables
]))):
self
.
seed
=
self
.
seed_def
self
.
animate
=
False
self
.
enables
=
[
False
]
*
num_enables
if
any
(
self
.
enables
[:
num_ws
]):
viz
.
args
.
stylemix_idx
=
[
idx
for
idx
,
enable
in
enumerate
(
self
.
enables
)
if
enable
]
viz
.
args
.
stylemix_seed
=
self
.
seed
&
((
1
<<
32
)
-
1
)
if
self
.
animate
:
self
.
seed
+=
1
#----------------------------------------------------------------------------
viz/trunc_noise_widget.py
0 → 100644
View file @
310493b2
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import
imgui
from
gui_utils
import
imgui_utils
#----------------------------------------------------------------------------
class
TruncationNoiseWidget
:
def
__init__
(
self
,
viz
):
self
.
viz
=
viz
self
.
prev_num_ws
=
0
self
.
trunc_psi
=
1
self
.
trunc_cutoff
=
0
self
.
noise_enable
=
True
self
.
noise_seed
=
0
self
.
noise_anim
=
False
@
imgui_utils
.
scoped_by_object_id
def
__call__
(
self
,
show
=
True
):
viz
=
self
.
viz
num_ws
=
viz
.
result
.
get
(
'num_ws'
,
0
)
has_noise
=
viz
.
result
.
get
(
'has_noise'
,
False
)
if
num_ws
>
0
and
num_ws
!=
self
.
prev_num_ws
:
if
self
.
trunc_cutoff
>
num_ws
or
self
.
trunc_cutoff
==
self
.
prev_num_ws
:
self
.
trunc_cutoff
=
num_ws
self
.
prev_num_ws
=
num_ws
if
show
:
imgui
.
text
(
'Truncate'
)
imgui
.
same_line
(
viz
.
label_w
)
with
imgui_utils
.
item_width
(
viz
.
font_size
*
10
),
imgui_utils
.
grayed_out
(
num_ws
==
0
):
_changed
,
self
.
trunc_psi
=
imgui
.
slider_float
(
'##psi'
,
self
.
trunc_psi
,
-
1
,
2
,
format
=
'Psi %.2f'
)
imgui
.
same_line
()
if
num_ws
==
0
:
imgui_utils
.
button
(
'Cutoff 0'
,
width
=
(
viz
.
font_size
*
8
+
viz
.
spacing
),
enabled
=
False
)
else
:
with
imgui_utils
.
item_width
(
viz
.
font_size
*
8
+
viz
.
spacing
):
changed
,
new_cutoff
=
imgui
.
slider_int
(
'##cutoff'
,
self
.
trunc_cutoff
,
0
,
num_ws
,
format
=
'Cutoff %d'
)
if
changed
:
self
.
trunc_cutoff
=
min
(
max
(
new_cutoff
,
0
),
num_ws
)
with
imgui_utils
.
grayed_out
(
not
has_noise
):
imgui
.
same_line
()
_clicked
,
self
.
noise_enable
=
imgui
.
checkbox
(
'Noise##enable'
,
self
.
noise_enable
)
imgui
.
same_line
(
round
(
viz
.
font_size
*
27.7
))
with
imgui_utils
.
grayed_out
(
not
self
.
noise_enable
):
with
imgui_utils
.
item_width
(
-
1
-
viz
.
button_w
-
viz
.
spacing
-
viz
.
font_size
*
4
):
_changed
,
self
.
noise_seed
=
imgui
.
input_int
(
'##seed'
,
self
.
noise_seed
)
imgui
.
same_line
(
spacing
=
0
)
_clicked
,
self
.
noise_anim
=
imgui
.
checkbox
(
'Anim##noise'
,
self
.
noise_anim
)
is_def_trunc
=
(
self
.
trunc_psi
==
1
and
self
.
trunc_cutoff
==
num_ws
)
is_def_noise
=
(
self
.
noise_enable
and
self
.
noise_seed
==
0
and
not
self
.
noise_anim
)
with
imgui_utils
.
grayed_out
(
is_def_trunc
and
not
has_noise
):
imgui
.
same_line
(
imgui
.
get_content_region_max
()[
0
]
-
1
-
viz
.
button_w
)
if
imgui_utils
.
button
(
'Reset'
,
width
=-
1
,
enabled
=
(
not
is_def_trunc
or
not
is_def_noise
)):
self
.
prev_num_ws
=
num_ws
self
.
trunc_psi
=
1
self
.
trunc_cutoff
=
num_ws
self
.
noise_enable
=
True
self
.
noise_seed
=
0
self
.
noise_anim
=
False
if
self
.
noise_anim
:
self
.
noise_seed
+=
1
viz
.
args
.
update
(
trunc_psi
=
self
.
trunc_psi
,
trunc_cutoff
=
self
.
trunc_cutoff
,
random_seed
=
self
.
noise_seed
)
viz
.
args
.
noise_mode
=
(
'none'
if
not
self
.
noise_enable
else
'const'
if
self
.
noise_seed
==
0
else
'random'
)
#----------------------------------------------------------------------------
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment