Commit 310493b2 authored by mashun1

stylegan3
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Generate images using pretrained network pickle."""
import os
import re
from typing import List, Optional, Tuple, Union
import click
import dnnlib
import numpy as np
import PIL.Image
import torch
import legacy
#----------------------------------------------------------------------------
def parse_range(s: Union[str, List]) -> List[int]:
'''Parse a comma-separated list of numbers or ranges and return a list of ints.
Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
'''
if isinstance(s, list): return s
ranges = []
range_re = re.compile(r'^(\d+)-(\d+)$')
for p in s.split(','):
m = range_re.match(p)
if m:
ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
else:
ranges.append(int(p))
return ranges
#----------------------------------------------------------------------------
def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
'''Parse a floating-point 2-vector with syntax 'a,b'.
Example:
'0,1' returns (0,1)
'''
if isinstance(s, tuple): return s
parts = s.split(',')
if len(parts) == 2:
return (float(parts[0]), float(parts[1]))
raise ValueError(f'cannot parse 2-vector {s}')
#----------------------------------------------------------------------------
def make_transform(translate: Tuple[float,float], angle: float):
m = np.eye(3)
s = np.sin(angle/360.0*np.pi*2)
c = np.cos(angle/360.0*np.pi*2)
m[0][0] = c
m[0][1] = s
m[0][2] = translate[0]
m[1][0] = -s
m[1][1] = c
m[1][2] = translate[1]
return m
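# Illustrative example (not part of the original script): make_transform((0.5, 0.0), 90.0)
# returns, up to floating-point error, the 3x3 matrix [[0, 1, 0.5], [-1, 0, 0], [0, 0, 1]],
# i.e. a 90-degree rotation plus a translation of 0.5 along x. generate_images() below
# inverts this matrix before copying it into G.synthesis.input.transform.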
#----------------------------------------------------------------------------
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
@click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
def generate_images(
network_pkl: str,
seeds: List[int],
truncation_psi: float,
noise_mode: str,
outdir: str,
translate: Tuple[float,float],
rotate: float,
class_idx: Optional[int]
):
"""Generate images using pretrained network pickle.
Examples:
\b
# Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
--network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
\b
# Generate uncurated images with truncation using the MetFaces-U dataset
python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
--network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
"""
print('Loading networks from "%s"...' % network_pkl)
device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
os.makedirs(outdir, exist_ok=True)
# Labels.
label = torch.zeros([1, G.c_dim], device=device)
if G.c_dim != 0:
if class_idx is None:
raise click.ClickException('Must specify class label with --class when using a conditional network')
label[:, class_idx] = 1
else:
if class_idx is not None:
print('warning: --class is ignored when running on an unconditional network')
# Generate images.
for seed_idx, seed in enumerate(seeds):
print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx + 1, len(seeds)))
z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
# Construct an inverse rotation/translation matrix and pass to the generator. The
# generator expects this matrix as an inverse to avoid potentially failing numerical
# operations in the network.
if hasattr(G.synthesis, 'input'):
m = make_transform(translate, rotate)
m = np.linalg.inv(m)
G.synthesis.input.transform.copy_(torch.from_numpy(m))
img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
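# G returns images in NCHW layout with values in roughly [-1, 1]; the line below converts
# them to HWC uint8 in [0, 255] for saving with PIL.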
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
#----------------------------------------------------------------------------
if __name__ == "__main__":
generate_images() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Generate lerp videos using pretrained network pickle."""
import copy
import os
import re
from typing import List, Optional, Tuple, Union
import click
import dnnlib
import imageio
import numpy as np
import scipy.interpolate
import torch
from tqdm import tqdm
import legacy
#----------------------------------------------------------------------------
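# Arranges a batch of images into a single grid image. With the default flags, a
# [batch, channels, H, W] float tensor (values in [-1, 1] when float_to_uint8 is used)
# comes back as a [grid_h*H, grid_w*W, channels] uint8 numpy array suitable for imageio.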
def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
batch_size, channels, img_h, img_w = img.shape
if grid_w is None:
grid_w = batch_size // grid_h
assert batch_size == grid_w * grid_h
if float_to_uint8:
img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
img = img.permute(2, 0, 3, 1, 4)
img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
if chw_to_hwc:
img = img.permute(1, 2, 0)
if to_numpy:
img = img.cpu().numpy()
return img
#----------------------------------------------------------------------------
def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), **video_kwargs):
grid_w = grid_dims[0]
grid_h = grid_dims[1]
if num_keyframes is None:
if len(seeds) % (grid_w*grid_h) != 0:
raise ValueError('Number of input seeds must be divisible by grid W*H')
num_keyframes = len(seeds) // (grid_w*grid_h)
all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64)
for idx in range(num_keyframes*grid_h*grid_w):
all_seeds[idx] = seeds[idx % len(seeds)]
if shuffle_seed is not None:
rng = np.random.RandomState(seed=shuffle_seed)
rng.shuffle(all_seeds)
zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device)
ws = G.mapping(z=zs, c=None, truncation_psi=psi)
_ = G.synthesis(ws[:1]) # warm up
ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])
# Interpolation.
grid = []
for yi in range(grid_h):
row = []
for xi in range(grid_w):
x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0)
row.append(interp)
grid.append(row)
# Render video.
video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs)
for frame_idx in tqdm(range(num_keyframes * w_frames)):
imgs = []
for yi in range(grid_h):
for xi in range(grid_w):
interp = grid[yi][xi]
w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)
img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0]
imgs.append(img)
video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
video_out.close()
#----------------------------------------------------------------------------
def parse_range(s: Union[str, List[int]]) -> List[int]:
'''Parse a comma-separated list of numbers or ranges and return a list of ints.
Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
'''
if isinstance(s, list): return s
ranges = []
range_re = re.compile(r'^(\d+)-(\d+)$')
for p in s.split(','):
m = range_re.match(p)
if m:
ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
else:
ranges.append(int(p))
return ranges
#----------------------------------------------------------------------------
def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
'''Parse an 'M,N' or 'MxN' integer tuple.
Example:
'4x2' returns (4,2)
'0,1' returns (0,1)
'''
if isinstance(s, tuple): return s
m = re.match(r'^(\d+)[x,](\d+)$', s)
if m:
return (int(m.group(1)), int(m.group(2)))
raise ValueError(f'cannot parse tuple {s}')
#----------------------------------------------------------------------------
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds', type=parse_range, help='List of random seeds', required=True)
@click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
@click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
@click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, it is determined from the number of seeds given with --seeds.', default=None)
@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE')
def generate_images(
network_pkl: str,
seeds: List[int],
shuffle_seed: Optional[int],
truncation_psi: float,
grid: Tuple[int,int],
num_keyframes: Optional[int],
w_frames: int,
output: str
):
"""Render a latent vector interpolation video.
Examples:
\b
# Render a 4x2 grid of interpolations for seeds 0 through 31.
python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\
--network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
Animation length and seed keyframes:
The animation length is either derived from the --seeds value or specified explicitly
with the --num-keyframes option.
When --num-keyframes is given, the output video is 'num_keyframes*w_frames' frames long.
If --num-keyframes is not given, the number of seeds passed with --seeds must be
divisible by the grid size W*H (--grid); in that case the output video is
'# seeds/(w*h)*w_frames' frames long.
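\b
Worked example: with --seeds=0-31 (32 seeds), --grid=4x2 and --w-frames=120, each grid
cell interpolates through 32/(4*2) = 4 keyframes, so the video is 4*120 = 480 frames
long (8 seconds at the fixed 60 FPS used by the renderer).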
"""
print('Loading networks from "%s"...' % network_pkl)
device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi)
#----------------------------------------------------------------------------
if __name__ == "__main__":
generate_images() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# empty
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import os
import functools
import contextlib
import numpy as np
import OpenGL.GL as gl
import OpenGL.GL.ARB.texture_float
import dnnlib
#----------------------------------------------------------------------------
def init_egl():
assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL.
import OpenGL.EGL as egl
import ctypes
# Initialize EGL.
display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
assert display != egl.EGL_NO_DISPLAY
major = ctypes.c_int32()
minor = ctypes.c_int32()
ok = egl.eglInitialize(display, major, minor)
assert ok
assert major.value * 10 + minor.value >= 14
# Choose config.
config_attribs = [
egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT,
egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT,
egl.EGL_NONE
]
configs = (ctypes.c_int32 * 1)()
num_configs = ctypes.c_int32()
ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs)
assert ok
assert num_configs.value == 1
config = configs[0]
# Create dummy pbuffer surface.
surface_attribs = [
egl.EGL_WIDTH, 1,
egl.EGL_HEIGHT, 1,
egl.EGL_NONE
]
surface = egl.eglCreatePbufferSurface(display, config, surface_attribs)
assert surface != egl.EGL_NO_SURFACE
# Setup GL context.
ok = egl.eglBindAPI(egl.EGL_OPENGL_API)
assert ok
context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None)
assert context != egl.EGL_NO_CONTEXT
ok = egl.eglMakeCurrent(display, surface, surface, context)
assert ok
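# Illustrative usage for headless rendering (a sketch, not part of the original module):
# set PYOPENGL_PLATFORM=egl in the environment before any OpenGL import, then call
# init_egl() once per process; after that, the Texture/Framebuffer helpers below can be
# used without an on-screen window.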
#----------------------------------------------------------------------------
_texture_formats = {
('uint8', 1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE, internalformat=gl.GL_LUMINANCE8),
('uint8', 2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8),
('uint8', 3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB, internalformat=gl.GL_RGB8),
('uint8', 4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA, internalformat=gl.GL_RGBA8),
('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB),
('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB),
('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGB, internalformat=gl.GL_RGB32F),
('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGBA, internalformat=gl.GL_RGBA32F),
}
def get_texture_format(dtype, channels):
return _texture_formats[(np.dtype(dtype).name, int(channels))]
#----------------------------------------------------------------------------
def prepare_texture_data(image):
image = np.asarray(image)
if image.ndim == 2:
image = image[:, :, np.newaxis]
if image.dtype.name == 'float64':
image = image.astype('float32')
return image
#----------------------------------------------------------------------------
def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True):
pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
align = np.broadcast_to(np.asarray(align, dtype='float32'), [2])
image = prepare_texture_data(image)
height, width, channels = image.shape
size = zoom * [width, height]
pos = pos - size * align
if rint:
pos = np.rint(pos)
fmt = get_texture_format(image.dtype, channels)
gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT)
gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
gl.glRasterPos2f(pos[0], pos[1])
gl.glPixelZoom(zoom[0], -zoom[1])
gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
gl.glDrawPixels(width, height, fmt.format, fmt.type, image)
gl.glPopClientAttrib()
gl.glPopAttrib()
#----------------------------------------------------------------------------
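# Reads the currently bound framebuffer into a [height, width, channels] numpy array.
# The result is flipped vertically so that row 0 is the top of the image, since OpenGL's
# origin is the bottom-left corner.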
def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3):
pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
dtype = np.dtype(dtype)
fmt = get_texture_format(dtype, channels)
image = np.empty([height, width, channels], dtype=dtype)
gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image)
gl.glPopClientAttrib()
return np.flipud(image)
#----------------------------------------------------------------------------
class Texture:
def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True):
self.gl_id = None
self.bilinear = bilinear
self.mipmap = mipmap
# Determine size and dtype.
if image is not None:
image = prepare_texture_data(image)
self.height, self.width, self.channels = image.shape
self.dtype = image.dtype
else:
assert width is not None and height is not None
self.width = width
self.height = height
self.channels = channels if channels is not None else 3
self.dtype = np.dtype(dtype) if dtype is not None else np.uint8
# Validate size and dtype.
assert isinstance(self.width, int) and self.width >= 0
assert isinstance(self.height, int) and self.height >= 0
assert isinstance(self.channels, int) and self.channels >= 1
assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype)
# Create texture object.
self.gl_id = gl.glGenTextures(1)
with self.bind():
gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST)
gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST)
self.update(image)
def delete(self):
if self.gl_id is not None:
gl.glDeleteTextures([self.gl_id])
self.gl_id = None
def __del__(self):
try:
self.delete()
except:
pass
@contextlib.contextmanager
def bind(self):
prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D)
gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id)
yield
gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id)
def update(self, image):
if image is not None:
image = prepare_texture_data(image)
assert self.is_compatible(image=image)
with self.bind():
fmt = get_texture_format(self.dtype, self.channels)
gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image)
if self.mipmap:
gl.glGenerateMipmap(gl.GL_TEXTURE_2D)
gl.glPopClientAttrib()
def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0):
zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
size = zoom * [self.width, self.height]
with self.bind():
gl.glPushAttrib(gl.GL_ENABLE_BIT)
gl.glEnable(gl.GL_TEXTURE_2D)
draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding)
gl.glPopAttrib()
def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements
if image is not None:
if image.ndim != 3:
return False
ih, iw, ic = image.shape
if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype):
return False
if width is not None and self.width != width:
return False
if height is not None and self.height != height:
return False
if channels is not None and self.channels != channels:
return False
if dtype is not None and self.dtype != dtype:
return False
return True
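# Illustrative usage (a sketch; assumes a current GL context, e.g. from init_egl() above):
#   tex = Texture(image=np.zeros([64, 64, 3], dtype=np.uint8))  # upload a black 64x64 RGB image
#   tex.draw(pos=(0, 0), zoom=4)                                # draw it at 4x magnification
#   tex.update(new_image)  # 'new_image' is a placeholder for an array of compatible shape/dtype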
#----------------------------------------------------------------------------
class Framebuffer:
def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0):
self.texture = texture
self.gl_id = None
self.gl_color = None
self.gl_depth_stencil = None
self.msaa = msaa
# Determine size and dtype.
if texture is not None:
assert isinstance(self.texture, Texture)
self.width = texture.width
self.height = texture.height
self.channels = texture.channels
self.dtype = texture.dtype
else:
assert width is not None and height is not None
self.width = width
self.height = height
self.channels = channels if channels is not None else 4
self.dtype = np.dtype(dtype) if dtype is not None else np.float32
# Validate size and dtype.
assert isinstance(self.width, int) and self.width >= 0
assert isinstance(self.height, int) and self.height >= 0
assert isinstance(self.channels, int) and self.channels >= 1
assert width is None or width == self.width
assert height is None or height == self.height
assert channels is None or channels == self.channels
assert dtype is None or dtype == self.dtype
# Create framebuffer object.
self.gl_id = gl.glGenFramebuffers(1)
with self.bind():
# Setup color buffer.
if self.texture is not None:
assert self.msaa == 0
gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0)
else:
fmt = get_texture_format(self.dtype, self.channels)
self.gl_color = gl.glGenRenderbuffers(1)
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color)
gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height)
gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color)
# Setup depth/stencil buffer.
self.gl_depth_stencil = gl.glGenRenderbuffers(1)
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil)
gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height)
gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil)
def delete(self):
if self.gl_id is not None:
gl.glDeleteFramebuffers([self.gl_id])
self.gl_id = None
if self.gl_color is not None:
gl.glDeleteRenderbuffers(1, [self.gl_color])
self.gl_color = None
if self.gl_depth_stencil is not None:
gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil])
self.gl_depth_stencil = None
def __del__(self):
try:
self.delete()
except:
pass
@contextlib.contextmanager
def bind(self):
prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id)
if self.width is not None and self.height is not None:
gl.glViewport(0, 0, self.width, self.height)
yield
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo)
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo)
def blit(self, dst=None):
assert dst is None or isinstance(dst, Framebuffer)
with self.bind():
gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.gl_id)  # Framebuffer has no '.fbo' attribute; use its GL object id.
gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST)
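# Illustrative usage (a sketch, not part of the original module): render off-screen into a
# texture-backed framebuffer, then copy the result to the default framebuffer.
#   tex = Texture(width=512, height=512, channels=3)
#   fbo = Framebuffer(texture=tex)
#   with fbo.bind():
#       draw_rect(pos=(0, 0), size=(512, 512), color=(1, 0, 0))  # GL drawing now targets 'tex'
#   fbo.blit()  # dst=None blits to the default framebuffer (id 0)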
#----------------------------------------------------------------------------
def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1):
assert vertices.ndim == 2 and vertices.shape[1] == 2
pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
size = np.broadcast_to(np.asarray(size, dtype='float32'), [2])
color = np.broadcast_to(np.asarray(color, dtype='float32'), [3])
alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1)
gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT)
gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT)
gl.glMatrixMode(gl.GL_MODELVIEW)
gl.glPushMatrix()
gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY)
gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices)
gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices)
gl.glTranslate(pos[0], pos[1], 0)
gl.glScale(size[0], size[1], 1)
gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha)
gl.glDrawArrays(mode, 0, vertices.shape[0])
gl.glPopMatrix()
gl.glPopAttrib()
gl.glPopClientAttrib()
#----------------------------------------------------------------------------
def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0):
assert pos2 is None or size is None
pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None
size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None
size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32')
pos = pos - size * align
if rint:
pos = np.rint(pos)
rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2])
rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5)
if np.min(rounding) == 0:
rounding *= 0
vertices = _setup_rect(float(rounding[0]), float(rounding[1]))
draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha)
@functools.lru_cache(maxsize=10000)
def _setup_rect(rx, ry):
t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64)
s = 1 - np.sin(t); c = 1 - np.cos(t)
x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx]
y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry]
v = np.stack([x, y], axis=-1).reshape(-1, 2)
return v.astype('float32')
#----------------------------------------------------------------------------
def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1):
hole = np.broadcast_to(np.asarray(hole, dtype='float32'), [])
vertices = _setup_circle(float(hole))
draw_shape(vertices, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha)
@functools.lru_cache(maxsize=10000)
def _setup_circle(hole):
t = np.linspace(0, np.pi * 2, 128)
s = np.sin(t); c = np.cos(t)
v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2)
return v.astype('float32')
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import time
import glfw
import OpenGL.GL as gl
from . import gl_utils
#----------------------------------------------------------------------------
class GlfwWindow: # pylint: disable=too-many-public-methods
def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
self._glfw_window = None
self._drawing_frame = False
self._frame_start_time = None
self._frame_delta = 0
self._fps_limit = None
self._vsync = None
self._skip_frames = 0
self._deferred_show = deferred_show
self._close_on_esc = close_on_esc
self._esc_pressed = False
self._drag_and_drop_paths = None
self._capture_next_frame = False
self._captured_frame = None
# Create window.
glfw.init()
glfw.window_hint(glfw.VISIBLE, False)
self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
self._attach_glfw_callbacks()
self.make_context_current()
# Adjust window.
self.set_vsync(False)
self.set_window_size(window_width, window_height)
if not self._deferred_show:
glfw.show_window(self._glfw_window)
def close(self):
if self._drawing_frame:
self.end_frame()
if self._glfw_window is not None:
glfw.destroy_window(self._glfw_window)
self._glfw_window = None
#glfw.terminate() # Commented out to play it nice with other glfw clients.
def __del__(self):
try:
self.close()
except:
pass
@property
def window_width(self):
return self.content_width
@property
def window_height(self):
return self.content_height + self.title_bar_height
@property
def content_width(self):
width, _height = glfw.get_window_size(self._glfw_window)
return width
@property
def content_height(self):
_width, height = glfw.get_window_size(self._glfw_window)
return height
@property
def title_bar_height(self):
_left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
return top
@property
def monitor_width(self):
_, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
return width
@property
def monitor_height(self):
_, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
return height
@property
def frame_delta(self):
return self._frame_delta
def set_title(self, title):
glfw.set_window_title(self._glfw_window, title)
def set_window_size(self, width, height):
width = min(width, self.monitor_width)
height = min(height, self.monitor_height)
glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
if width == self.monitor_width and height == self.monitor_height:
self.maximize()
def set_content_size(self, width, height):
self.set_window_size(width, height + self.title_bar_height)
def maximize(self):
glfw.maximize_window(self._glfw_window)
def set_position(self, x, y):
glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)
def center(self):
self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)
def set_vsync(self, vsync):
vsync = bool(vsync)
if vsync != self._vsync:
glfw.swap_interval(1 if vsync else 0)
self._vsync = vsync
def set_fps_limit(self, fps_limit):
self._fps_limit = int(fps_limit)
def should_close(self):
return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)
def skip_frame(self):
self.skip_frames(1)
def skip_frames(self, num): # Do not update window for the next N frames.
self._skip_frames = max(self._skip_frames, int(num))
def is_skipping_frames(self):
return self._skip_frames > 0
def capture_next_frame(self):
self._capture_next_frame = True
def pop_captured_frame(self):
frame = self._captured_frame
self._captured_frame = None
return frame
def pop_drag_and_drop_paths(self):
paths = self._drag_and_drop_paths
self._drag_and_drop_paths = None
return paths
def draw_frame(self): # To be overridden by subclass.
self.begin_frame()
# Rendering code goes here.
self.end_frame()
def make_context_current(self):
if self._glfw_window is not None:
glfw.make_context_current(self._glfw_window)
def begin_frame(self):
# End previous frame.
if self._drawing_frame:
self.end_frame()
# Apply FPS limit.
if self._frame_start_time is not None and self._fps_limit is not None:
delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
if delay > 0:
time.sleep(delay)
cur_time = time.perf_counter()
if self._frame_start_time is not None:
self._frame_delta = cur_time - self._frame_start_time
self._frame_start_time = cur_time
# Process events.
glfw.poll_events()
# Begin frame.
self._drawing_frame = True
self.make_context_current()
# Initialize GL state.
gl.glViewport(0, 0, self.content_width, self.content_height)
gl.glMatrixMode(gl.GL_PROJECTION)
gl.glLoadIdentity()
gl.glTranslate(-1, 1, 0)
gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
gl.glMatrixMode(gl.GL_MODELVIEW)
gl.glLoadIdentity()
gl.glEnable(gl.GL_BLEND)
gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.
# Clear.
gl.glClearColor(0, 0, 0, 1)
gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
def end_frame(self):
assert self._drawing_frame
self._drawing_frame = False
# Skip frames if requested.
if self._skip_frames > 0:
self._skip_frames -= 1
return
# Capture frame if requested.
if self._capture_next_frame:
self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
self._capture_next_frame = False
# Update window.
if self._deferred_show:
glfw.show_window(self._glfw_window)
self._deferred_show = False
glfw.swap_buffers(self._glfw_window)
def _attach_glfw_callbacks(self):
glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)
def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
self._esc_pressed = True
def _glfw_drop_callback(self, _window, paths):
self._drag_and_drop_paths = paths
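# Illustrative usage (hypothetical subclass, not part of this module):
#   class DemoWindow(GlfwWindow):
#       def draw_frame(self):
#           self.begin_frame()
#           # ... issue GL draw calls here, e.g. via gl_utils.draw_rect() ...
#           self.end_frame()
#   win = DemoWindow(title='Demo', deferred_show=False)
#   while not win.should_close():
#       win.draw_frame()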
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import contextlib
import imgui
#----------------------------------------------------------------------------
def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
s = imgui.get_style()
s.window_padding = [spacing, spacing]
s.item_spacing = [spacing, spacing]
s.item_inner_spacing = [spacing, spacing]
s.columns_min_spacing = spacing
s.indent_spacing = indent
s.scrollbar_size = scrollbar
s.frame_padding = [4, 3]
s.window_border_size = 1
s.child_border_size = 1
s.popup_border_size = 1
s.frame_border_size = 1
s.window_rounding = 0
s.child_rounding = 0
s.popup_rounding = 3
s.frame_rounding = 3
s.scrollbar_rounding = 3
s.grab_rounding = 3
getattr(imgui, f'style_colors_{color_scheme}')(s)
c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND]
s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1]
#----------------------------------------------------------------------------
@contextlib.contextmanager
def grayed_out(cond=True):
if cond:
s = imgui.get_style()
text = s.colors[imgui.COLOR_TEXT_DISABLED]
grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB]
back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
imgui.push_style_color(imgui.COLOR_TEXT, *text)
imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab)
imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab)
imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab)
imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back)
imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back)
imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back)
imgui.push_style_color(imgui.COLOR_BUTTON, *back)
imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back)
imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back)
imgui.push_style_color(imgui.COLOR_HEADER, *back)
imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back)
imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back)
imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back)
yield
imgui.pop_style_color(14)
else:
yield
#----------------------------------------------------------------------------
@contextlib.contextmanager
def item_width(width=None):
if width is not None:
imgui.push_item_width(width)
yield
imgui.pop_item_width()
else:
yield
#----------------------------------------------------------------------------
def scoped_by_object_id(method):
def decorator(self, *args, **kwargs):
imgui.push_id(str(id(self)))
res = method(self, *args, **kwargs)
imgui.pop_id()
return res
return decorator
#----------------------------------------------------------------------------
def button(label, width=0, enabled=True):
with grayed_out(not enabled):
clicked = imgui.button(label, width=width)
clicked = clicked and enabled
return clicked
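# Illustrative usage inside an active ImGui frame (a sketch; 'ready' is a hypothetical flag):
#   with item_width(200):
#       if button('Run', enabled=ready):
#           ...  # handle the click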
#----------------------------------------------------------------------------
def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
expanded = False
if show:
if default:
flags |= imgui.TREE_NODE_DEFAULT_OPEN
if not enabled:
flags |= imgui.TREE_NODE_LEAF
with grayed_out(not enabled):
expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags)
expanded = expanded and enabled
return expanded, visible
#----------------------------------------------------------------------------
def popup_button(label, width=0, enabled=True):
if button(label, width, enabled):
imgui.open_popup(label)
opened = imgui.begin_popup(label)
return opened
#----------------------------------------------------------------------------
def input_text(label, value, buffer_length, flags, width=None, help_text=''):
old_value = value
color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
if value == '':
color[-1] *= 0.5
with item_width(width):
imgui.push_style_color(imgui.COLOR_TEXT, *color)
value = value if value != '' else help_text
changed, value = imgui.input_text(label, value, buffer_length, flags)
value = value if value != help_text else ''
imgui.pop_style_color(1)
if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
changed = (value != old_value)
return changed, value
#----------------------------------------------------------------------------
def drag_previous_control(enabled=True):
dragging = False
dx = 0
dy = 0
if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
if enabled:
dragging = True
dx, dy = imgui.get_mouse_drag_delta()
imgui.reset_mouse_drag_delta()
imgui.end_drag_drop_source()
return dragging, dx, dy
#----------------------------------------------------------------------------
def drag_button(label, width=0, enabled=True):
clicked = button(label, width=width, enabled=enabled)
dragging, dx, dy = drag_previous_control(enabled=enabled)
return clicked, dragging, dx, dy
#----------------------------------------------------------------------------
def drag_hidden_window(label, x, y, width, height, enabled=True):
imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
imgui.set_next_window_position(x, y)
imgui.set_next_window_size(width, height)
imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
dragging, dx, dy = drag_previous_control(enabled=enabled)
imgui.end()
imgui.pop_style_color(2)
return dragging, dx, dy
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import os
import imgui
import imgui.integrations.glfw
from . import glfw_window
from . import imgui_utils
from . import text_utils
#----------------------------------------------------------------------------
class ImguiWindow(glfw_window.GlfwWindow):
def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
if font is None:
font = text_utils.get_default_font()
font_sizes = {int(size) for size in font_sizes}
super().__init__(title=title, **glfw_kwargs)
# Init fields.
self._imgui_context = None
self._imgui_renderer = None
self._imgui_fonts = None
self._cur_font_size = max(font_sizes)
# Delete leftover imgui.ini to avoid unexpected behavior.
if os.path.isfile('imgui.ini'):
os.remove('imgui.ini')
# Init ImGui.
self._imgui_context = imgui.create_context()
self._imgui_renderer = _GlfwRenderer(self._glfw_window)
self._attach_glfw_callbacks()
imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom().
self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes}
self._imgui_renderer.refresh_font_texture()
def close(self):
self.make_context_current()
self._imgui_fonts = None
if self._imgui_renderer is not None:
self._imgui_renderer.shutdown()
self._imgui_renderer = None
if self._imgui_context is not None:
#imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end.
self._imgui_context = None
super().close()
def _glfw_key_callback(self, *args):
super()._glfw_key_callback(*args)
self._imgui_renderer.keyboard_callback(*args)
@property
def font_size(self):
return self._cur_font_size
@property
def spacing(self):
return round(self._cur_font_size * 0.4)
def set_font_size(self, target): # Applied on next frame.
self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1]
def begin_frame(self):
# Begin glfw frame.
super().begin_frame()
# Process imgui events.
self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
if self.content_width > 0 and self.content_height > 0:
self._imgui_renderer.process_inputs()
# Begin imgui frame.
imgui.new_frame()
imgui.push_font(self._imgui_fonts[self._cur_font_size])
imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4)
def end_frame(self):
imgui.pop_font()
imgui.render()
imgui.end_frame()
self._imgui_renderer.render(imgui.get_draw_data())
super().end_frame()
#----------------------------------------------------------------------------
# Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux.
class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mouse_wheel_multiplier = 1
def scroll_callback(self, window, x_offset, y_offset):
self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import functools
from typing import Optional
import dnnlib
import numpy as np
import PIL.Image
import PIL.ImageFont
import scipy.ndimage
from . import gl_utils
#----------------------------------------------------------------------------
def get_default_font():
url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular
return dnnlib.util.open_url(url, return_filename=True)
#----------------------------------------------------------------------------
@functools.lru_cache(maxsize=None)
def get_pil_font(font=None, size=32):
if font is None:
font = get_default_font()
return PIL.ImageFont.truetype(font=font, size=size)
#----------------------------------------------------------------------------
def get_array(string, *, dropshadow_radius: int=None, **kwargs):
if dropshadow_radius is not None:
offset_x = int(np.ceil(dropshadow_radius*2/3))
offset_y = int(np.ceil(dropshadow_radius*2/3))
return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
else:
return _get_array_priv(string, **kwargs)
@functools.lru_cache(maxsize=10000)
def _get_array_priv(
string: str, *,
size: int = 32,
max_width: Optional[int]=None,
max_height: Optional[int]=None,
min_size=10,
shrink_coef=0.8,
dropshadow_radius: int=None,
offset_x: int=None,
offset_y: int=None,
**kwargs
):
cur_size = size
array = None
while True:
if dropshadow_radius is not None:
# separate implementation for dropshadow text rendering
array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
else:
array = _get_array_impl(string, size=cur_size, **kwargs)
height, width, _ = array.shape
if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size):
break
cur_size = max(int(cur_size * shrink_coef), min_size)
return array
#----------------------------------------------------------------------------
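# Renders a (possibly multi-line) string into a [height, width, 2] uint8 array holding a
# luminance channel and an alpha channel; the optional Gaussian outline is baked into the
# alpha channel. The 2-channel layout matches the ('uint8', 2) texture format in gl_utils.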
@functools.lru_cache(maxsize=10000)
def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
pil_font = get_pil_font(font=font, size=size)
lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
width = max(line.shape[1] for line in lines)
lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
line_spacing = line_pad if line_pad is not None else size // 2
lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
mask = np.concatenate(lines, axis=0)
alpha = mask
if outline > 0:
mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0)
alpha = mask.astype(np.float32) / 255
alpha = scipy.ndimage.gaussian_filter(alpha, outline)
alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp
alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
alpha = np.maximum(alpha, mask)
return np.stack([mask, alpha], axis=-1)
#----------------------------------------------------------------------------
@functools.lru_cache(maxsize=10000)
def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
assert (offset_x > 0) and (offset_y > 0)
pil_font = get_pil_font(font=font, size=size)
lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
width = max(line.shape[1] for line in lines)
lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
line_spacing = line_pad if line_pad is not None else size // 2
lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
mask = np.concatenate(lines, axis=0)
alpha = mask
mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0)
alpha = mask.astype(np.float32) / 255
alpha = scipy.ndimage.gaussian_filter(alpha, radius)
alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4
alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x]
alpha = np.maximum(alpha, mask)
return np.stack([mask, alpha], axis=-1)
#----------------------------------------------------------------------------
@functools.lru_cache(maxsize=10000)
def get_texture(string, bilinear=True, mipmap=True, **kwargs):
return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap)
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Converting legacy network pickle into the new format."""
import click
import pickle
import re
import copy
import numpy as np
import torch
import dnnlib
from torch_utils import misc
#----------------------------------------------------------------------------
def load_network_pkl(f, force_fp16=False):
data = _LegacyUnpickler(f).load()
# Legacy TensorFlow pickle => convert.
if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
tf_G, tf_D, tf_Gs = data
G = convert_tf_generator(tf_G)
D = convert_tf_discriminator(tf_D)
G_ema = convert_tf_generator(tf_Gs)
data = dict(G=G, D=D, G_ema=G_ema)
# Add missing fields.
if 'training_set_kwargs' not in data:
data['training_set_kwargs'] = None
if 'augment_pipe' not in data:
data['augment_pipe'] = None
# Validate contents.
assert isinstance(data['G'], torch.nn.Module)
assert isinstance(data['D'], torch.nn.Module)
assert isinstance(data['G_ema'], torch.nn.Module)
assert isinstance(data['training_set_kwargs'], (dict, type(None)))
assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
# Force FP16.
if force_fp16:
for key in ['G', 'D', 'G_ema']:
old = data[key]
kwargs = copy.deepcopy(old.init_kwargs)
fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
fp16_kwargs.num_fp16_res = 4
fp16_kwargs.conv_clamp = 256
if kwargs != old.init_kwargs:
new = type(old)(**kwargs).eval().requires_grad_(False)
misc.copy_params_and_buffers(old, new, require_all=True)
data[key] = new
return data
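# Illustrative usage (mirrors gen_images.py and the CLI example below):
#   with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl') as f:
#       G = load_network_pkl(f)['G_ema']  # converted to the PyTorch Generator class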
#----------------------------------------------------------------------------
class _TFNetworkStub(dnnlib.EasyDict):
pass
class _LegacyUnpickler(pickle.Unpickler):
def find_class(self, module, name):
if module == 'dnnlib.tflib.network' and name == 'Network':
return _TFNetworkStub
return super().find_class(module, name)
#----------------------------------------------------------------------------
def _collect_tf_params(tf_net):
# pylint: disable=protected-access
tf_params = dict()
def recurse(prefix, tf_net):
for name, value in tf_net.variables:
tf_params[prefix + name] = value
for name, comp in tf_net.components.items():
recurse(prefix + name + '/', comp)
recurse('', tf_net)
return tf_params
#----------------------------------------------------------------------------
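# Copies values into a module's parameters and buffers. 'patterns' is a flat list of
# alternating (regex, value_fn) pairs: each parameter/buffer name is matched against the
# regexes in order, and the first matching value_fn (called with the regex groups) supplies
# the new value; a value_fn of None means "match, but leave the tensor unchanged".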
def _populate_module_params(module, *patterns):
for name, tensor in misc.named_params_and_buffers(module):
found = False
value = None
for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
match = re.fullmatch(pattern, name)
if match:
found = True
if value_fn is not None:
value = value_fn(*match.groups())
break
try:
assert found
if value is not None:
tensor.copy_(torch.from_numpy(np.array(value)))
except:
print(name, list(tensor.shape))
raise
#----------------------------------------------------------------------------
def convert_tf_generator(tf_G):
if tf_G.version < 4:
raise ValueError('TensorFlow pickle version too low')
# Collect kwargs.
tf_kwargs = tf_G.static_kwargs
known_kwargs = set()
def kwarg(tf_name, default=None, none=None):
known_kwargs.add(tf_name)
val = tf_kwargs.get(tf_name, default)
return val if val is not None else none
# Convert kwargs.
from training import networks_stylegan2
network_class = networks_stylegan2.Generator
kwargs = dnnlib.EasyDict(
z_dim = kwarg('latent_size', 512),
c_dim = kwarg('label_size', 0),
w_dim = kwarg('dlatent_size', 512),
img_resolution = kwarg('resolution', 1024),
img_channels = kwarg('num_channels', 3),
channel_base = kwarg('fmap_base', 16384) * 2,
channel_max = kwarg('fmap_max', 512),
num_fp16_res = kwarg('num_fp16_res', 0),
conv_clamp = kwarg('conv_clamp', None),
architecture = kwarg('architecture', 'skip'),
resample_filter = kwarg('resample_kernel', [1,3,3,1]),
use_noise = kwarg('use_noise', True),
activation = kwarg('nonlinearity', 'lrelu'),
mapping_kwargs = dnnlib.EasyDict(
num_layers = kwarg('mapping_layers', 8),
embed_features = kwarg('label_fmaps', None),
layer_features = kwarg('mapping_fmaps', None),
activation = kwarg('mapping_nonlinearity', 'lrelu'),
lr_multiplier = kwarg('mapping_lrmul', 0.01),
w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
),
)
# Check for unknown kwargs.
kwarg('truncation_psi')
kwarg('truncation_cutoff')
kwarg('style_mixing_prob')
kwarg('structure')
kwarg('conditioning')
kwarg('fused_modconv')
unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
if len(unknown_kwargs) > 0:
raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
# Collect params.
tf_params = _collect_tf_params(tf_G)
for name, value in list(tf_params.items()):
match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
if match:
r = kwargs.img_resolution // (2 ** int(match.group(1)))
tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
kwargs.synthesis.kwargs.architecture = 'orig'
#for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
# Convert params.
G = network_class(**kwargs).eval().requires_grad_(False)
# pylint: disable=unnecessary-lambda
# pylint: disable=f-string-without-interpolation
_populate_module_params(G,
r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
r'.*\.resample_filter', None,
r'.*\.act_filter', None,
)
return G
#----------------------------------------------------------------------------
def convert_tf_discriminator(tf_D):
if tf_D.version < 4:
raise ValueError('TensorFlow pickle version too low')
# Collect kwargs.
tf_kwargs = tf_D.static_kwargs
known_kwargs = set()
def kwarg(tf_name, default=None):
known_kwargs.add(tf_name)
return tf_kwargs.get(tf_name, default)
# Convert kwargs.
kwargs = dnnlib.EasyDict(
c_dim = kwarg('label_size', 0),
img_resolution = kwarg('resolution', 1024),
img_channels = kwarg('num_channels', 3),
architecture = kwarg('architecture', 'resnet'),
channel_base = kwarg('fmap_base', 16384) * 2,
channel_max = kwarg('fmap_max', 512),
num_fp16_res = kwarg('num_fp16_res', 0),
conv_clamp = kwarg('conv_clamp', None),
cmap_dim = kwarg('mapping_fmaps', None),
block_kwargs = dnnlib.EasyDict(
activation = kwarg('nonlinearity', 'lrelu'),
resample_filter = kwarg('resample_kernel', [1,3,3,1]),
freeze_layers = kwarg('freeze_layers', 0),
),
mapping_kwargs = dnnlib.EasyDict(
num_layers = kwarg('mapping_layers', 0),
embed_features = kwarg('mapping_fmaps', None),
layer_features = kwarg('mapping_fmaps', None),
activation = kwarg('nonlinearity', 'lrelu'),
lr_multiplier = kwarg('mapping_lrmul', 0.1),
),
epilogue_kwargs = dnnlib.EasyDict(
mbstd_group_size = kwarg('mbstd_group_size', None),
mbstd_num_channels = kwarg('mbstd_num_features', 1),
activation = kwarg('nonlinearity', 'lrelu'),
),
)
# Check for unknown kwargs.
kwarg('structure')
kwarg('conditioning')
unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
if len(unknown_kwargs) > 0:
raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
# Collect params.
tf_params = _collect_tf_params(tf_D)
for name, value in list(tf_params.items()):
match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
if match:
r = kwargs.img_resolution // (2 ** int(match.group(1)))
tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
kwargs.architecture = 'orig'
#for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
# Convert params.
from training import networks_stylegan2
D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
# pylint: disable=unnecessary-lambda
# pylint: disable=f-string-without-interpolation
_populate_module_params(D,
r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
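        # In the TF naming scheme, Conv0 keeps its plain name while Conv1 performs the
        # downsampling and is stored as 'Conv1_down'; the ["", "_down"][int(i)] lookup below
        # reproduces that suffix.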
r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
r'.*\.resample_filter', None,
)
return D
#----------------------------------------------------------------------------
@click.command()
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
def convert_network_pickle(source, dest, force_fp16):
"""Convert legacy network pickle into the native PyTorch format.
The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
Example:
\b
python legacy.py \\
--source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
--dest=stylegan2-cat-config-f.pkl
"""
print(f'Loading "{source}"...')
with dnnlib.util.open_url(source) as f:
data = load_network_pkl(f, force_fp16=force_fp16)
print(f'Saving "{dest}"...')
with open(dest, 'wb') as f:
pickle.dump(data, f)
print('Done.')
#----------------------------------------------------------------------------
if __name__ == "__main__":
convert_network_pickle() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------
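# Hypothetical usage sketch, assuming the converted pickle follows the usual
# StyleGAN2/3 layout with a 'G_ema' entry (see load_network_pkl above):
#
#   import pickle
#   import torch
#   with open('stylegan2-cat-config-f.pkl', 'rb') as f:
#       G = pickle.load(f)['G_ema'].eval().requires_grad_(False).cuda()
#   z = torch.randn([1, G.z_dim], device='cuda')
#   img = G(z, c=None)  # NCHW, float32, dynamic range [-1, 1]
#----------------------------------------------------------------------------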
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# empty
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
"Alias-Free Generative Adversarial Networks"."""
import copy
import numpy as np
import torch
import torch.fft
from torch_utils.ops import upfirdn2d
from . import metric_utils
#----------------------------------------------------------------------------
# Utilities.
def sinc(x):
y = (x * np.pi).abs()
z = torch.sin(y) / y.clamp(1e-30, float('inf'))
return torch.where(y < 1e-30, torch.ones_like(x), z)
def lanczos_window(x, a):
x = x.abs() / a
return torch.where(x < 1, sinc(x), torch.zeros_like(x))
def rotation_matrix(angle):
angle = torch.as_tensor(angle).to(torch.float32)
mat = torch.eye(3, device=angle.device)
mat[0, 0] = angle.cos()
mat[0, 1] = angle.sin()
mat[1, 0] = -angle.sin()
mat[1, 1] = angle.cos()
return mat
#----------------------------------------------------------------------------
# Apply integer translation to a batch of 2D images. Corresponds to the
# operator T_x in Appendix E.1.
def apply_integer_translation(x, tx, ty):
_N, _C, H, W = x.shape
tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
ix = tx.round().to(torch.int64)
iy = ty.round().to(torch.int64)
z = torch.zeros_like(x)
m = torch.zeros_like(x)
if abs(ix) < W and abs(iy) < H:
y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)]
z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y
m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1
return z, m
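# tx and ty are given as fractions of the image width and height, so e.g.
# apply_integer_translation(x, 0.25, 0.0) shifts the batch right by round(0.25 * W) pixels
# and returns a mask marking which output pixels received valid (shifted) content.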
#----------------------------------------------------------------------------
# Apply fractional translation to a batch of 2D images. Corresponds to the
# operator T_x in Appendix E.2.
def apply_fractional_translation(x, tx, ty, a=3):
_N, _C, H, W = x.shape
tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
ix = tx.floor().to(torch.int64)
iy = ty.floor().to(torch.int64)
fx = tx - ix
fy = ty - iy
b = a - 1
z = torch.zeros_like(x)
zx0 = max(ix - b, 0)
zy0 = max(iy - b, 0)
zx1 = min(ix + a, 0) + W
zy1 = min(iy + a, 0) + H
if zx0 < zx1 and zy0 < zy1:
taps = torch.arange(a * 2, device=x.device) - b
filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
y = x
y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
z[:, :, zy0:zy1, zx0:zx1] = y
m = torch.zeros_like(x)
mx0 = max(ix + a, 0)
my0 = max(iy + a, 0)
mx1 = min(ix - b, 0) + W
my1 = min(iy - b, 0) + H
if mx0 < mx1 and my0 < my1:
m[:, :, my0:my1, mx0:mx1] = 1
return z, m
#----------------------------------------------------------------------------
# Construct an oriented low-pass filter that applies the appropriate
# bandlimit with respect to the input and output of the given affine 2D
# image transformation.
def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
assert a <= amax < aflt
mat = torch.as_tensor(mat).to(torch.float32)
# Construct 2D filter taps in input & output coordinate spaces.
taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
yi, xi = torch.meshgrid(taps, taps)
xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)
# Convolution of two oriented 2D sinc filters.
fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real
# Convolution of two oriented 2D Lanczos windows.
wi = lanczos_window(xi, a) * lanczos_window(yi, a)
wo = lanczos_window(xo, a) * lanczos_window(yo, a)
w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real
# Construct windowed FIR filter.
f = f * w
# Finalize.
c = (aflt - amax) * up
f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
f = f / f.sum([0,2], keepdim=True) / (up ** 2)
f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
return f
#----------------------------------------------------------------------------
# Apply the given affine transformation to a batch of 2D images.
def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
_N, _C, H, W = x.shape
mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
# Construct filter.
f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
p = f.shape[0] // 2
# Construct sampling grid.
theta = mat.inverse()
theta[:2, 2] *= 2
theta[0, 2] += 1 / up / W
theta[1, 2] += 1 / up / H
theta[0, :] *= W / (W + p / up * 2)
theta[1, :] *= H / (H + p / up * 2)
theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)
# Resample image.
y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
# Form mask.
m = torch.zeros_like(y)
c = p * 2 + 1
m[:, :, c:-c, c:-c] = 1
m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
return z, m
#----------------------------------------------------------------------------
# Apply fractional rotation to a batch of 2D images. Corresponds to the
# operator R_\alpha in Appendix E.3.
def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
mat = rotation_matrix(angle)
return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)
#----------------------------------------------------------------------------
# Modify the frequency content of a batch of 2D images as if they had undergone
# fractional rotation -- but without actually rotating them. Corresponds to
# the operator R^*_\alpha in Appendix E.3.
def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
mat = rotation_matrix(-angle)
f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs)
y = upfirdn2d.filter2d(x=x, f=f)
m = torch.zeros_like(y)
c = f.shape[0] // 2
m[:, :, c:-c, c:-c] = 1
return y, m
#----------------------------------------------------------------------------
# Compute the selected equivariance metrics for the given generator.
def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
assert compute_eqt_int or compute_eqt_frac or compute_eqr
# Setup generator and labels.
G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
I = torch.eye(3, device=opts.device)
M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
if M is None:
raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
# Sampling loop.
sums = None
progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
progress.update(batch_start)
s = []
# Randomize noise buffers, if any.
for name, buf in G.named_buffers():
if name.endswith('.noise_const'):
buf.copy_(torch.randn_like(buf))
# Run mapping network.
z = torch.randn([batch_size, G.z_dim], device=opts.device)
c = next(c_iter)
ws = G.mapping(z=z, c=c)
# Generate reference image.
M[:] = I
orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
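        # M is the user-specified transform applied to the generator's input grid, i.e. the
        # inverse of the desired image-space transformation. Writing -t (or rotation_matrix(-angle))
        # below therefore yields an output translated by +t (rotated by +angle), which is compared
        # against the reference image transformed with the exact operator.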
# Integer translation (EQ-T).
if compute_eqt_int:
t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
t = (t * G.img_resolution).round() / G.img_resolution
M[:] = I
M[:2, 2] = -t
img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
ref, mask = apply_integer_translation(orig, t[0], t[1])
s += [(ref - img).square() * mask, mask]
# Fractional translation (EQ-T_frac).
if compute_eqt_frac:
t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
M[:] = I
M[:2, 2] = -t
img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
ref, mask = apply_fractional_translation(orig, t[0], t[1])
s += [(ref - img).square() * mask, mask]
# Rotation (EQ-R).
if compute_eqr:
angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
M[:] = rotation_matrix(-angle)
img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
ref, ref_mask = apply_fractional_rotation(orig, angle)
pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
mask = ref_mask * pseudo_mask
s += [(ref - pseudo).square() * mask, mask]
# Accumulate results.
s = torch.stack([x.to(torch.float64).sum() for x in s])
sums = sums + s if sums is not None else s
progress.update(num_samples)
# Compute PSNRs.
if opts.num_gpus > 1:
torch.distributed.all_reduce(sums)
sums = sums.cpu()
mses = sums[0::2] / sums[1::2]
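    # Images live in [-1, 1], so the peak signal is 2 and
    # PSNR = 10*log10(2^2 / MSE) = 20*log10(2) - 10*log10(MSE).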
psnrs = np.log10(2) * 20 - mses.log10() * 10
psnrs = tuple(psnrs.numpy())
return psnrs[0] if len(psnrs) == 1 else psnrs
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Frechet Inception Distance (FID) from the paper
"GANs trained by a two time-scale update rule converge to a local Nash
equilibrium". Matches the original implementation by Heusel et al. at
https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
import numpy as np
import scipy.linalg
from . import metric_utils
#----------------------------------------------------------------------------
def compute_fid(opts, max_real, num_gen):
# Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
if opts.rank != 0:
return float('nan')
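    # Closed-form Frechet distance between the Gaussians (mu_real, sigma_real) and (mu_gen, sigma_gen):
    # ||mu_g - mu_r||^2 + Tr(sigma_g + sigma_r - 2*(sigma_g sigma_r)^(1/2)).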
m = np.square(mu_gen - mu_real).sum()
s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
return float(fid)
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Inception Score (IS) from the paper "Improved techniques for training
GANs". Matches the original implementation by Salimans et al. at
https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
import numpy as np
from . import metric_utils
#----------------------------------------------------------------------------
def compute_is(opts, num_gen, num_splits):
# Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
gen_probs = metric_utils.compute_feature_stats_for_generator(
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
capture_all=True, max_items=num_gen).get_all()
if opts.rank != 0:
return float('nan'), float('nan')
scores = []
for i in range(num_splits):
part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
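        # KL divergence between the conditional class distribution p(y|x) and the marginal p(y)
        # estimated over this split; IS = exp(E_x[KL(p(y|x) || p(y))]).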
kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
kl = np.mean(np.sum(kl, axis=1))
scores.append(np.exp(kl))
return float(np.mean(scores)), float(np.std(scores))
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
GANs". Matches the original implementation by Binkowski et al. at
https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
import numpy as np
from . import metric_utils
#----------------------------------------------------------------------------
def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
# Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
real_features = metric_utils.compute_feature_stats_for_dataset(
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
gen_features = metric_utils.compute_feature_stats_for_generator(
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
if opts.rank != 0:
return float('nan')
n = real_features.shape[1]
m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
t = 0
for _subset_idx in range(num_subsets):
x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
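        # Polynomial kernel k(u, v) = (u.v / n + 1)^3 with n = feature dimension; 'a' collects the
        # within-set kernel values and 'b' the cross-set values for the unbiased MMD^2 estimate.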
a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
b = (x @ y.T / n + 1) ** 3
t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
kid = t / num_subsets / m
return float(kid)
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Main API for computing and reporting quality metrics."""
import os
import time
import json
import torch
import dnnlib
from . import metric_utils
from . import frechet_inception_distance
from . import kernel_inception_distance
from . import precision_recall
from . import perceptual_path_length
from . import inception_score
from . import equivariance
#----------------------------------------------------------------------------
_metric_dict = dict() # name => fn
def register_metric(fn):
assert callable(fn)
_metric_dict[fn.__name__] = fn
return fn
def is_valid_metric(metric):
return metric in _metric_dict
def list_valid_metrics():
return list(_metric_dict.keys())
#----------------------------------------------------------------------------
def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
assert is_valid_metric(metric)
opts = metric_utils.MetricOptions(**kwargs)
# Calculate.
start_time = time.time()
results = _metric_dict[metric](opts)
total_time = time.time() - start_time
# Broadcast results.
for key, value in list(results.items()):
if opts.num_gpus > 1:
value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
torch.distributed.broadcast(tensor=value, src=0)
value = float(value.cpu())
results[key] = value
# Decorate with metadata.
return dnnlib.EasyDict(
results = dnnlib.EasyDict(results),
metric = metric,
total_time = total_time,
total_time_str = dnnlib.util.format_time(total_time),
num_gpus = opts.num_gpus,
)
#----------------------------------------------------------------------------
def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
metric = result_dict['metric']
assert is_valid_metric(metric)
if run_dir is not None and snapshot_pkl is not None:
snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
print(jsonl_line)
if run_dir is not None and os.path.isdir(run_dir):
with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
f.write(jsonl_line + '\n')
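# Hypothetical usage sketch, assuming a generator G and dataset_kwargs prepared by the caller;
# see metric_utils.MetricOptions for the accepted fields:
#
#   result = calc_metric('fid50k_full', G=G, dataset_kwargs=dataset_kwargs,
#       num_gpus=1, rank=0, device=torch.device('cuda'))
#   report_metric(result, run_dir='out', snapshot_pkl='network-snapshot.pkl')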
#----------------------------------------------------------------------------
# Recommended metrics.
@register_metric
def fid50k_full(opts):
opts.dataset_kwargs.update(max_size=None, xflip=False)
fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
return dict(fid50k_full=fid)
@register_metric
def kid50k_full(opts):
opts.dataset_kwargs.update(max_size=None, xflip=False)
kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
return dict(kid50k_full=kid)
@register_metric
def pr50k3_full(opts):
opts.dataset_kwargs.update(max_size=None, xflip=False)
precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
@register_metric
def ppl2_wend(opts):
ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
return dict(ppl2_wend=ppl)
@register_metric
def eqt50k_int(opts):
opts.G_kwargs.update(force_fp32=True)
psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
return dict(eqt50k_int=psnr)
@register_metric
def eqt50k_frac(opts):
opts.G_kwargs.update(force_fp32=True)
psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
return dict(eqt50k_frac=psnr)
@register_metric
def eqr50k(opts):
opts.G_kwargs.update(force_fp32=True)
psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
return dict(eqr50k=psnr)
#----------------------------------------------------------------------------
# Legacy metrics.
@register_metric
def fid50k(opts):
opts.dataset_kwargs.update(max_size=None)
fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
return dict(fid50k=fid)
@register_metric
def kid50k(opts):
opts.dataset_kwargs.update(max_size=None)
kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
return dict(kid50k=kid)
@register_metric
def pr50k3(opts):
opts.dataset_kwargs.update(max_size=None)
precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
return dict(pr50k3_precision=precision, pr50k3_recall=recall)
@register_metric
def is50k(opts):
opts.dataset_kwargs.update(max_size=None, xflip=False)
mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
return dict(is50k_mean=mean, is50k_std=std)
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
Architecture for Generative Adversarial Networks". Matches the original
implementation by Karras et al. at
https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
import copy
import numpy as np
import torch
from . import metric_utils
#----------------------------------------------------------------------------
# Spherical interpolation of a batch of vectors.
def slerp(a, b, t):
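    # Normalize both endpoints, measure the angle acos(<a, b>) between them, build c as the
    # unit vector orthogonal to a in the a-b plane, then rotate a toward b by t times that angle.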
a = a / a.norm(dim=-1, keepdim=True)
b = b / b.norm(dim=-1, keepdim=True)
d = (a * b).sum(dim=-1, keepdim=True)
p = t * torch.acos(d)
c = b - d * a
c = c / c.norm(dim=-1, keepdim=True)
d = a * torch.cos(p) + c * torch.sin(p)
d = d / d.norm(dim=-1, keepdim=True)
return d
#----------------------------------------------------------------------------
class PPLSampler(torch.nn.Module):
def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
assert space in ['z', 'w']
assert sampling in ['full', 'end']
super().__init__()
self.G = copy.deepcopy(G)
self.G_kwargs = G_kwargs
self.epsilon = epsilon
self.space = space
self.sampling = sampling
self.crop = crop
self.vgg16 = copy.deepcopy(vgg16)
def forward(self, c):
# Generate random latents and interpolation t-values.
t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
# Interpolate in W or Z.
if self.space == 'w':
w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
else: # space == 'z'
zt0 = slerp(z0, z1, t.unsqueeze(1))
zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
# Randomize noise buffers.
for name, buf in self.G.named_buffers():
if name.endswith('.noise_const'):
buf.copy_(torch.randn_like(buf))
# Generate images.
img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
# Center crop.
if self.crop:
assert img.shape[2] == img.shape[3]
c = img.shape[2] // 8
img = img[:, :, c*3 : c*7, c*2 : c*6]
# Downsample to 256x256.
factor = self.G.img_resolution // 256
if factor > 1:
img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
# Scale dynamic range from [-1,1] to [0,255].
img = (img + 1) * (255 / 2)
if self.G.img_channels == 1:
img = img.repeat([1, 3, 1, 1])
# Evaluate differential LPIPS.
lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
return dist
#----------------------------------------------------------------------------
def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
# Setup sampler and labels.
sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
sampler.eval().requires_grad_(False).to(opts.device)
c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
# Sampling loop.
dist = []
progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
progress.update(batch_start)
x = sampler(next(c_iter))
for src in range(opts.num_gpus):
y = x.clone()
if opts.num_gpus > 1:
torch.distributed.broadcast(y, src=src)
dist.append(y)
progress.update(num_samples)
# Compute PPL.
if opts.rank != 0:
return float('nan')
dist = torch.cat(dist)[:num_samples].cpu().numpy()
lo = np.percentile(dist, 1, interpolation='lower')
hi = np.percentile(dist, 99, interpolation='higher')
ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
return float(ppl)
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Precision/Recall (PR) from the paper "Improved Precision and Recall
Metric for Assessing Generative Models". Matches the original implementation
by Kynkaanniemi et al. at
https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
import torch
from . import metric_utils
#----------------------------------------------------------------------------
def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
assert 0 <= rank < num_gpus
num_cols = col_features.shape[0]
num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
dist_batches = []
for col_batch in col_batches[rank :: num_gpus]:
dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
for src in range(num_gpus):
dist_broadcast = dist_batch.clone()
if num_gpus > 1:
torch.distributed.broadcast(dist_broadcast, src=src)
dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
#----------------------------------------------------------------------------
def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
detector_kwargs = dict(return_features=True)
real_features = metric_utils.compute_feature_stats_for_dataset(
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
gen_features = metric_utils.compute_feature_stats_for_generator(
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
results = dict()
for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
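        # Manifold estimation via k-NN radii: each manifold feature gets the distance to its
        # nhood_size-th nearest neighbour (kthvalue(nhood_size + 1) skips the zero self-distance),
        # and a probe counts as covered if it falls within that radius of any manifold feature.
        # Precision uses the real manifold with generated probes; recall swaps the roles.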
kth = []
for manifold_batch in manifold.split(row_batch_size):
dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
kth = torch.cat(kth) if opts.rank == 0 else None
pred = []
for probes_batch in probes.split(row_batch_size):
dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
return results['precision'], results['recall']
#----------------------------------------------------------------------------
# Model name
modelName=Stylegan3_Pytorch
# Model description
modelDescription=Stylegan3_Pytorch is an improved version of StyleGAN2 that achieves high-quality equivariance
# Application scenarios
appScenario=inference,training,pretraining,cv,image generation
# Framework type
frameType=PyTorch