ModelZoo / DragGAN_pytorch / Commits

Commit fba8bde8, authored Oct 29, 2024 by bailuo: "update"
Pipeline #1808 failed with stages. Changes: 224. Pipelines: 1.

Showing 20 changed files with 3425 additions and 0 deletions (+3425, -0)
stylegan_human/pti/pti_models/e4e/stylegan2/op/upfirdn2d.cpp         +24   -0
stylegan_human/pti/pti_models/e4e/stylegan2/op/upfirdn2d.py          +61   -0
stylegan_human/pti/pti_models/e4e/stylegan2/op/upfirdn2d_kernel.cu   +273  -0
stylegan_human/pti/training/__init__.py                              +0    -0
stylegan_human/pti/training/coaches/__init__.py                      +0    -0
stylegan_human/pti/training/coaches/base_coach.py                    +150  -0
stylegan_human/pti/training/coaches/localitly_regulizer.py           +63   -0
stylegan_human/pti/training/coaches/multi_id_coach.py                +79   -0
stylegan_human/pti/training/coaches/single_id_coach.py               +89   -0
stylegan_human/pti/training/projectors/__init__.py                   +0    -0
stylegan_human/pti/training/projectors/w_plus_projector.py           +147  -0
stylegan_human/pti/training/projectors/w_projector.py                +144  -0
stylegan_human/run_pti.py                                            +52   -0
stylegan_human/style_mixing.py                                       +108  -0
stylegan_human/stylemixing_video.py                                  +154  -0
stylegan_human/torch_utils/__init__.py                               +11   -0
stylegan_human/torch_utils/custom_ops.py                             +239  -0
stylegan_human/torch_utils/misc.py                                   +264  -0
stylegan_human/torch_utils/models.py                                 +757  -0
stylegan_human/torch_utils/models_face.py                            +810  -0
stylegan_human/pti/pti_models/e4e/stylegan2/op/upfirdn2d.cpp  (new file, mode 100644)

#include <torch/extension.h>

torch::Tensor upfirdn2d_op(const torch::Tensor &input, const torch::Tensor &kernel,
                           int up_x, int up_y, int down_x, int down_y,
                           int pad_x0, int pad_x1, int pad_y0, int pad_y1);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
  CHECK_CUDA(input);
  CHECK_CUDA(kernel);

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y,
                      pad_x0, pad_x1, pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
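The binding above only declares and exposes the CUDA op; nothing in this commit shows how it gets compiled. As a minimal sketch (an assumption, not part of the diff), the pair upfirdn2d.cpp / upfirdn2d_kernel.cu can be JIT-built with torch.utils.cpp_extension.load and called through the pybind name registered above. The tensor layout and padding values below are illustrative only.

# Hedged sketch: JIT-compile and call the raw op (assumes a CUDA toolchain and
# that this file sits next to upfirdn2d.cpp / upfirdn2d_kernel.cu).
import os
import torch
from torch.utils.cpp_extension import load

module_path = os.path.dirname(__file__)
upfirdn2d_ext = load(
    "upfirdn2d",
    sources=[
        os.path.join(module_path, "upfirdn2d.cpp"),
        os.path.join(module_path, "upfirdn2d_kernel.cu"),
    ],
)

x = torch.randn(1, 3, 64, 64, device="cuda")
k = torch.ones(4, 4, device="cuda") / 16.0
# The raw op works on [major, in_h, in_w, minor] tensors, so fold batch and
# channel dims into "major" first (this mirrors what the full StyleGAN2 wrapper does).
inp = x.reshape(-1, 64, 64, 1)
# Arguments follow the pybind signature: up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1.
out = upfirdn2d_ext.upfirdn2d(inp, k, 1, 1, 1, 1, 1, 2, 1, 2)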
stylegan_human/pti/pti_models/e4e/stylegan2/op/upfirdn2d.py  (new file, mode 100644)

import os

import torch
from torch.nn import functional as F


module_path = os.path.dirname(__file__)


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    out = upfirdn2d_native(
        input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
    )

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
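For reference, a small usage sketch of the pure-PyTorch fallback above (not part of the commit; the filter and padding values are illustrative): upfirdn2d upsamples by `up`, convolves with a 2D FIR kernel, then downsamples by `down`, returning a tensor of shape [N, C, out_h, out_w].

# Hedged example: 2x upsample with a separable 4-tap filter, the kind of call
# StyleGAN2 makes on its upsampling path. Values chosen only for illustration.
import torch

x = torch.randn(2, 3, 32, 32)
k = torch.tensor([1.0, 3.0, 3.0, 1.0])
k = torch.outer(k, k)
k = k / k.sum()

y = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))
# out_h = (32*2 + 2 + 1 - 4) // 1 + 1 = 64, likewise for width.
print(y.shape)  # torch.Size([2, 3, 64, 64])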
stylegan_human/pti/pti_models/e4e/stylegan2/op/upfirdn2d_kernel.cu  (new file, mode 100644)

// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

#include <cuda.h>
#include <cuda_runtime.h>

static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}

struct UpFirDn2DKernelParams {
  int up_x;
  int up_y;
  int down_x;
  int down_y;
  int pad_x0;
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;
  int in_h;
  int in_w;
  int minor_dim;
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major;
  int loop_x;
};

template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim},
                       x.options());

  int mode = -1;

  int tile_out_h;
  int tile_out_w;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(
              out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
              k.data_ptr<scalar_t>(), p);
      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(
              out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
              k.data_ptr<scalar_t>(), p);
      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(
              out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
              k.data_ptr<scalar_t>(), p);
      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(
              out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
              k.data_ptr<scalar_t>(), p);
      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(
              out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
              k.data_ptr<scalar_t>(), p);
      break;

    case 6:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(
              out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
              k.data_ptr<scalar_t>(), p);
      break;
    }
  });

  return out;
}
stylegan_human/pti/training/__init__.py  (new file, mode 100644, empty)
stylegan_human/pti/training/coaches/__init__.py  (new file, mode 100644, empty)
stylegan_human/pti/training/coaches/base_coach.py  (new file, mode 100644)

import abc
import os
import pickle
from argparse import Namespace
import wandb
import os.path
from .localitly_regulizer import Space_Regulizer, l2_loss
import torch
from torchvision import transforms
from lpips import LPIPS
from pti.training.projectors import w_projector
from pti.pti_configs import global_config, paths_config, hyperparameters
from pti.pti_models.e4e.psp import pSp
from utils.log_utils import log_image_from_w
from utils.models_utils import toogle_grad, load_old_G


class BaseCoach:
    def __init__(self, data_loader, use_wandb):

        self.use_wandb = use_wandb
        self.data_loader = data_loader
        self.w_pivots = {}
        self.image_counter = 0

        if hyperparameters.first_inv_type == 'w+':
            self.initilize_e4e()

        self.e4e_image_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((256, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

        # Initialize loss
        self.lpips_loss = LPIPS(net=hyperparameters.lpips_type).to(global_config.device).eval()

        self.restart_training()

        # Initialize checkpoint dir
        self.checkpoint_dir = paths_config.checkpoints_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def restart_training(self):

        # Initialize networks
        self.G = load_old_G()
        toogle_grad(self.G, True)

        self.original_G = load_old_G()

        self.space_regulizer = Space_Regulizer(self.original_G, self.lpips_loss)
        self.optimizer = self.configure_optimizers()

    def get_inversion(self, w_path_dir, image_name, image):
        embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
        os.makedirs(embedding_dir, exist_ok=True)

        w_pivot = None

        if hyperparameters.use_last_w_pivots:
            w_pivot = self.load_inversions(w_path_dir, image_name)

        if not hyperparameters.use_last_w_pivots or w_pivot is None:
            w_pivot = self.calc_inversions(image, image_name)
            torch.save(w_pivot, f'{embedding_dir}/0.pt')

        w_pivot = w_pivot.to(global_config.device)
        return w_pivot

    def load_inversions(self, w_path_dir, image_name):
        if image_name in self.w_pivots:
            return self.w_pivots[image_name]

        if hyperparameters.first_inv_type == 'w+':
            w_potential_path = f'{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}/0.pt'
        else:
            w_potential_path = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}/0.pt'
        if not os.path.isfile(w_potential_path):
            return None
        w = torch.load(w_potential_path).to(global_config.device)
        self.w_pivots[image_name] = w
        return w

    def calc_inversions(self, image, image_name):

        if hyperparameters.first_inv_type == 'w+':
            w = self.get_e4e_inversion(image)

        else:
            id_image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255
            w = w_projector.project(self.G, id_image, device=torch.device(global_config.device), w_avg_samples=600,
                                    num_steps=hyperparameters.first_inv_steps, w_name=image_name,
                                    use_wandb=self.use_wandb)

        return w

    @abc.abstractmethod
    def train(self):
        pass

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.G.parameters(), lr=hyperparameters.pti_learning_rate)

        return optimizer

    def calc_loss(self, generated_images, real_images, log_name, new_G, use_ball_holder, w_batch):
        loss = 0.0

        if hyperparameters.pt_l2_lambda > 0:
            l2_loss_val = l2_loss(generated_images, real_images)
            if self.use_wandb:
                wandb.log({f'MSE_loss_val_{log_name}': l2_loss_val.detach().cpu()}, step=global_config.training_step)
            loss += l2_loss_val * hyperparameters.pt_l2_lambda
        if hyperparameters.pt_lpips_lambda > 0:
            loss_lpips = self.lpips_loss(generated_images, real_images)
            loss_lpips = torch.squeeze(loss_lpips)
            if self.use_wandb:
                wandb.log({f'LPIPS_loss_val_{log_name}': loss_lpips.detach().cpu()}, step=global_config.training_step)
            loss += loss_lpips * hyperparameters.pt_lpips_lambda

        if use_ball_holder and hyperparameters.use_locality_regularization:
            ball_holder_loss_val = self.space_regulizer.space_regulizer_loss(new_G, w_batch, use_wandb=self.use_wandb)
            loss += ball_holder_loss_val

        return loss, l2_loss_val, loss_lpips

    def forward(self, w):
        generated_images = self.G.synthesis(w, noise_mode='const', force_fp32=True)

        return generated_images

    def initilize_e4e(self):
        ckpt = torch.load(paths_config.e4e, map_location='cpu')
        opts = ckpt['opts']
        opts['batch_size'] = hyperparameters.train_batch_size
        opts['checkpoint_path'] = paths_config.e4e
        opts = Namespace(**opts)
        self.e4e_inversion_net = pSp(opts)
        self.e4e_inversion_net.eval()
        self.e4e_inversion_net = self.e4e_inversion_net.to(global_config.device)
        toogle_grad(self.e4e_inversion_net, False)

    def get_e4e_inversion(self, image):
        image = (image + 1) / 2
        new_image = self.e4e_image_transform(image[0]).to(global_config.device)
        _, w = self.e4e_inversion_net(new_image.unsqueeze(0), randomize_noise=False, return_latents=True,
                                      resize=False, input_code=False)
        if self.use_wandb:
            log_image_from_w(w, self.G, 'First e4e inversion')
        return w
stylegan_human/pti/training/coaches/localitly_regulizer.py  (new file, mode 100644)

import torch
import numpy as np
import wandb
from pti.pti_configs import hyperparameters, global_config

l2_criterion = torch.nn.MSELoss(reduction='mean')


def l2_loss(real_images, generated_images):
    loss = l2_criterion(real_images, generated_images)
    return loss


class Space_Regulizer:
    def __init__(self, original_G, lpips_net):
        self.original_G = original_G
        self.morphing_regulizer_alpha = hyperparameters.regulizer_alpha
        self.lpips_loss = lpips_net

    def get_morphed_w_code(self, new_w_code, fixed_w):
        interpolation_direction = new_w_code - fixed_w
        interpolation_direction_norm = torch.norm(interpolation_direction, p=2)
        direction_to_move = hyperparameters.regulizer_alpha * interpolation_direction / interpolation_direction_norm
        result_w = fixed_w + direction_to_move
        self.morphing_regulizer_alpha * fixed_w + (1 - self.morphing_regulizer_alpha) * new_w_code

        return result_w

    def get_image_from_ws(self, w_codes, G):
        return torch.cat([G.synthesis(w_code, noise_mode='none', force_fp32=True) for w_code in w_codes])

    def ball_holder_loss_lazy(self, new_G, num_of_sampled_latents, w_batch, use_wandb=False):
        loss = 0.0

        z_samples = np.random.randn(num_of_sampled_latents, self.original_G.z_dim)
        w_samples = self.original_G.mapping(torch.from_numpy(z_samples).to(global_config.device), None,
                                            truncation_psi=0.5)
        territory_indicator_ws = [self.get_morphed_w_code(w_code.unsqueeze(0), w_batch) for w_code in w_samples]

        for w_code in territory_indicator_ws:
            new_img = new_G.synthesis(w_code, noise_mode='none', force_fp32=True)
            with torch.no_grad():
                old_img = self.original_G.synthesis(w_code, noise_mode='none', force_fp32=True)

            if hyperparameters.regulizer_l2_lambda > 0:
                l2_loss_val = l2_loss.l2_loss(old_img, new_img)
                if use_wandb:
                    wandb.log({f'space_regulizer_l2_loss_val': l2_loss_val.detach().cpu()},
                              step=global_config.training_step)
                loss += l2_loss_val * hyperparameters.regulizer_l2_lambda

            if hyperparameters.regulizer_lpips_lambda > 0:
                loss_lpips = self.lpips_loss(old_img, new_img)
                loss_lpips = torch.mean(torch.squeeze(loss_lpips))
                if use_wandb:
                    wandb.log({f'space_regulizer_lpips_loss_val': loss_lpips.detach().cpu()},
                              step=global_config.training_step)
                loss += loss_lpips * hyperparameters.regulizer_lpips_lambda

        return loss / len(territory_indicator_ws)

    def space_regulizer_loss(self, new_G, w_batch, use_wandb):
        ret_val = self.ball_holder_loss_lazy(new_G, hyperparameters.latent_ball_num_of_samples, w_batch, use_wandb)
        return ret_val
stylegan_human/pti/training/coaches/multi_id_coach.py  (new file, mode 100644)

# Copyright (c) SenseTime Research. All rights reserved.

import os

import torch
from tqdm import tqdm

from pti.pti_configs import paths_config, hyperparameters, global_config
from pti.training.coaches.base_coach import BaseCoach
from utils.log_utils import log_images_from_w


class MultiIDCoach(BaseCoach):

    def __init__(self, data_loader, use_wandb):
        super().__init__(data_loader, use_wandb)

    def train(self):
        self.G.synthesis.train()
        self.G.mapping.train()

        w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
        os.makedirs(w_path_dir, exist_ok=True)
        os.makedirs(f'{w_path_dir}/{paths_config.pti_results_keyword}', exist_ok=True)

        use_ball_holder = True
        w_pivots = []
        images = []

        for fname, image in self.data_loader:
            if self.image_counter >= hyperparameters.max_images_to_invert:
                break

            image_name = fname[0]
            if hyperparameters.first_inv_type == 'w+':
                embedding_dir = f'{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}'
            else:
                embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
            os.makedirs(embedding_dir, exist_ok=True)

            w_pivot = self.get_inversion(w_path_dir, image_name, image)
            w_pivots.append(w_pivot)
            images.append((image_name, image))
            self.image_counter += 1

        for i in tqdm(range(hyperparameters.max_pti_steps)):
            self.image_counter = 0

            for data, w_pivot in zip(images, w_pivots):
                image_name, image = data

                if self.image_counter >= hyperparameters.max_images_to_invert:
                    break

                real_images_batch = image.to(global_config.device)

                generated_images = self.forward(w_pivot)
                loss, l2_loss_val, loss_lpips = self.calc_loss(generated_images, real_images_batch, image_name,
                                                               self.G, use_ball_holder, w_pivot)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                use_ball_holder = global_config.training_step % hyperparameters.locality_regularization_interval == 0

                global_config.training_step += 1
                self.image_counter += 1

        if self.use_wandb:
            log_images_from_w(w_pivots, self.G, [image[0] for image in images])

        # torch.save(self.G,
        #            f'{paths_config.checkpoints_dir}/model_{global_config.run_name}_multi_id.pt')
        snapshot_data = dict()
        snapshot_data['G_ema'] = self.G
        import pickle
        with open(f'{paths_config.checkpoints_dir}/model_{global_config.run_name}_multi_id.pkl', 'wb') as f:
            pickle.dump(snapshot_data, f)
stylegan_human/pti/training/coaches/single_id_coach.py  (new file, mode 100644)

# Copyright (c) SenseTime Research. All rights reserved.

import os
import torch
from tqdm import tqdm
from pti.pti_configs import paths_config, hyperparameters, global_config
from pti.training.coaches.base_coach import BaseCoach
from utils.log_utils import log_images_from_w
from torchvision.utils import save_image


class SingleIDCoach(BaseCoach):

    def __init__(self, data_loader, use_wandb):
        super().__init__(data_loader, use_wandb)

    def train(self):

        w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
        os.makedirs(w_path_dir, exist_ok=True)
        os.makedirs(f'{w_path_dir}/{paths_config.pti_results_keyword}', exist_ok=True)

        use_ball_holder = True

        for fname, image in tqdm(self.data_loader):
            image_name = fname[0]

            self.restart_training()

            if self.image_counter >= hyperparameters.max_images_to_invert:
                break

            embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
            os.makedirs(embedding_dir, exist_ok=True)

            w_pivot = None

            if hyperparameters.use_last_w_pivots:
                w_pivot = self.load_inversions(w_path_dir, image_name)
            # Copyright (c) SenseTime Research. All rights reserved.
            elif not hyperparameters.use_last_w_pivots or w_pivot is None:
                w_pivot = self.calc_inversions(image, image_name)

            # w_pivot = w_pivot.detach().clone().to(global_config.device)
            w_pivot = w_pivot.to(global_config.device)
            torch.save(w_pivot, f'{embedding_dir}/0.pt')
            log_images_counter = 0
            real_images_batch = image.to(global_config.device)

            for i in range(hyperparameters.max_pti_steps):
                generated_images = self.forward(w_pivot)
                loss, l2_loss_val, loss_lpips = self.calc_loss(generated_images, real_images_batch, image_name,
                                                               self.G, use_ball_holder, w_pivot)
                if i == 0:
                    tmp1 = torch.clone(generated_images)
                if i % 10 == 0:
                    print("pti loss: ", i, loss.data, loss_lpips.data)
                self.optimizer.zero_grad()

                if loss_lpips <= hyperparameters.LPIPS_value_threshold:
                    break

                loss.backward()
                self.optimizer.step()

                use_ball_holder = global_config.training_step % hyperparameters.locality_regularization_interval == 0

                if self.use_wandb and log_images_counter % global_config.image_rec_result_log_snapshot == 0:
                    log_images_from_w([w_pivot], self.G, [image_name])

                global_config.training_step += 1
                log_images_counter += 1

            # save output image
            tmp = torch.cat([real_images_batch, tmp1, generated_images], axis=3)
            save_image(tmp, f"{paths_config.experiments_output_dir}/{image_name}.png", normalize=True)

            self.image_counter += 1

            # torch.save(self.G,
            #            f'{paths_config.checkpoints_dir}/model_{image_name}.pt')  # '.pt'
            snapshot_data = dict()
            snapshot_data['G_ema'] = self.G
            import pickle
            with open(f'{paths_config.checkpoints_dir}/model_{image_name}.pkl', 'wb') as f:
                pickle.dump(snapshot_data, f)
stylegan_human/pti/training/projectors/__init__.py  (new file, mode 100644, empty)
stylegan_human/pti/training/projectors/w_plus_projector.py  (new file, mode 100644)

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Project given image to the latent space of pretrained network pickle."""

import copy
import wandb
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from configs import global_config, hyperparameters
import dnnlib
from utils.log_utils import log_image_from_w


def project(
        G,
        target: torch.Tensor,  # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
        *,
        num_steps=1000,
        w_avg_samples=10000,
        initial_learning_rate=0.01,
        initial_noise_factor=0.05,
        lr_rampdown_length=0.25,
        lr_rampup_length=0.05,
        noise_ramp_length=0.75,
        regularize_noise_weight=1e5,
        verbose=False,
        device: torch.device,
        use_wandb=False,
        initial_w=None,
        image_log_step=global_config.image_rec_result_log_snapshot,
        w_name: str
):
    print('inside training/projectors/w_plus_projector')
    print(target.shape, G.img_channels, G.img_resolution * 2, G.img_resolution)
    assert target.shape == (G.img_channels, G.img_resolution * 2, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float()  # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
    w_avg_tensor = torch.from_numpy(w_avg).to(global_config.device)
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    start_w = initial_w if initial_w is not None else w_avg

    # Setup noise inputs.
    noise_bufs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name}

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    start_w = np.repeat(start_w, G.mapping.num_ws, axis=1)
    w_opt = torch.tensor(start_w, dtype=torch.float32, device=device,
                         requires_grad=True)  # pylint: disable=not-callable

    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999),
                                 lr=hyperparameters.first_inv_lr)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in tqdm(range(num_steps)):

        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise)

        synth_images = G.synthesis(ws, noise_mode='const', force_fp32=True)

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255 / 2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        if step % image_log_step == 0:
            with torch.no_grad():
                if use_wandb:
                    global_config.training_step += 1
                    wandb.log({f'first projection _{w_name}': loss.detach().cpu()}, step=global_config.training_step)
                    log_image_from_w(w_opt, G, w_name)

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    del G
    return w_opt
stylegan_human/pti/training/projectors/w_projector.py  (new file, mode 100644)

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Project given image to the latent space of pretrained network pickle."""

import copy
import wandb
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from pti.pti_configs import global_config, hyperparameters
from utils import log_utils
import dnnlib


def project(
        G,
        target: torch.Tensor,  # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
        *,
        num_steps=1000,
        w_avg_samples=10000,
        initial_learning_rate=0.01,
        initial_noise_factor=0.05,
        lr_rampdown_length=0.25,
        lr_rampup_length=0.05,
        noise_ramp_length=0.75,
        regularize_noise_weight=1e5,
        verbose=False,
        device: torch.device,
        use_wandb=False,
        initial_w=None,
        image_log_step=global_config.image_rec_result_log_snapshot,
        w_name: str
):
    print(target.shape, G.img_channels, G.img_resolution, G.img_resolution // 2)
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution // 2)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float()  # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
    w_avg_tensor = torch.from_numpy(w_avg).to(global_config.device)
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    start_w = initial_w if initial_w is not None else w_avg

    # Setup noise inputs.
    noise_bufs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name}

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    w_opt = torch.tensor(start_w, dtype=torch.float32, device=device,
                         requires_grad=True)  # pylint: disable=not-callable
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999),
                                 lr=hyperparameters.first_inv_lr)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in range(num_steps):

        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(ws, noise_mode='const', force_fp32=True)

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255 / 2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        if step % 10 == 0:
            print("project loss", step, loss.data)

        if step % image_log_step == 0:
            with torch.no_grad():
                if use_wandb:
                    global_config.training_step += 1
                    wandb.log({f'first projection _{w_name}': loss.detach().cpu()}, step=global_config.training_step)
                    log_utils.log_image_from_w(w_opt.repeat([1, G.mapping.num_ws, 1]), G, w_name)

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    del G
    return w_opt.repeat([1, 18, 1])
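A hedged sketch of how this W-space projector is driven, mirroring the call made in base_coach.py. The dummy image tensor, the 'example' name, and the use of load_old_G here are illustrative assumptions, not part of the commit.

# Hedged usage sketch: invert one 1024x512 human image into a W pivot.
import torch
from pti.pti_configs import global_config, hyperparameters
from pti.training.projectors import w_projector
from utils.models_utils import load_old_G

G = load_old_G()
# Dummy stand-in for an aligned image in [-1, 1]; in practice this comes from
# utils.ImagesDataset as in run_pti.py.
image = torch.zeros(1, 3, 1024, 512)
id_image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255  # [3, H, W] in [0, 255]

w_pivot = w_projector.project(
    G, id_image,
    device=torch.device(global_config.device),
    w_avg_samples=600,
    num_steps=hyperparameters.first_inv_steps,
    w_name='example',
    use_wandb=False,
)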
stylegan_human/run_pti.py  (new file, mode 100644)

# Copyright (c) SenseTime Research. All rights reserved.

from random import choice
from string import ascii_uppercase
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
import os
from pti.pti_configs import global_config, paths_config
import wandb

from pti.training.coaches.multi_id_coach import MultiIDCoach
from pti.training.coaches.single_id_coach import SingleIDCoach
from utils.ImagesDataset import ImagesDataset


def run_PTI(run_name='', use_wandb=False, use_multi_id_training=False):
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices

    if run_name == '':
        global_config.run_name = ''.join(choice(ascii_uppercase) for i in range(12))
    else:
        global_config.run_name = run_name

    if use_wandb:
        run = wandb.init(project=paths_config.pti_results_keyword, reinit=True, name=global_config.run_name)
    global_config.pivotal_training_steps = 1
    global_config.training_step = 1

    embedding_dir_path = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{paths_config.pti_results_keyword}'
    # print('embedding_dir_path: ', embedding_dir_path)  # ./embeddings/barcelona/PTI
    os.makedirs(embedding_dir_path, exist_ok=True)

    dataset = ImagesDataset(paths_config.input_data_path, transforms.Compose([
        transforms.Resize((1024, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]))

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    if use_multi_id_training:
        coach = MultiIDCoach(dataloader, use_wandb)
    else:
        coach = SingleIDCoach(dataloader, use_wandb)

    coach.train()

    return global_config.run_name


if __name__ == '__main__':
    run_PTI(run_name='', use_wandb=False, use_multi_id_training=False)
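A hedged driver sketch (not part of the commit): run_PTI reads its inputs from pti.pti_configs.paths_config, so a caller typically points input_data_path and input_data_id (both referenced above) at the aligned images before invoking it. The example values below are assumptions.

# Hedged example of driving run_PTI from another script.
from pti.pti_configs import paths_config
from run_pti import run_PTI

paths_config.input_data_path = './aligned_images'   # folder of 1024x512 aligned human images (assumed path)
paths_config.input_data_id = 'my_run'                # names the embedding sub-folder

model_id = run_PTI(run_name='demo', use_wandb=False, use_multi_id_training=False)
print('finished PTI run:', model_id)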
stylegan_human/style_mixing.py  (new file, mode 100644)

# Copyright (c) SenseTime Research. All rights reserved.

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
#

import os
import re
from typing import List
import legacy
import click
import dnnlib
import numpy as np
import PIL.Image
import torch

"""
Style mixing using pretrained network pickle.
Examples:

\b
python style_mixing.py --network=pretrained_models/stylegan_human_v2_1024.pkl --rows=85,100,75,458,1500 \\
    --cols=55,821,1789,293 --styles=0-3 --outdir=outputs/stylemixing
"""


@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--rows', 'row_seeds', type=legacy.num_range, help='Random seeds to use for image rows', required=True)
@click.option('--cols', 'col_seeds', type=legacy.num_range, help='Random seeds to use for image columns', required=True)
@click.option('--styles', 'col_styles', type=legacy.num_range, help='Style layer range', default='0-6', show_default=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.8, show_default=True)
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--outdir', type=str, required=True, default='outputs/stylemixing')
def generate_style_mix(
    network_pkl: str,
    row_seeds: List[int],
    col_seeds: List[int],
    col_styles: List[int],
    truncation_psi: float,
    noise_mode: str,
    outdir: str
):

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
    dtype = torch.float32 if device.type == 'mps' else torch.float64
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)

    os.makedirs(outdir, exist_ok=True)

    print('Generating W vectors...')
    all_seeds = list(set(row_seeds + col_seeds))
    all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
    all_w = G.mapping(torch.from_numpy(all_z).to(device, dtype=dtype), None)
    w_avg = G.mapping.w_avg
    all_w = w_avg + (all_w - w_avg) * truncation_psi
    w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}

    print('Generating images...')
    all_images = G.synthesis(all_w, noise_mode=noise_mode)
    all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
    image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}

    print('Generating style-mixed images...')
    for row_seed in row_seeds:
        for col_seed in col_seeds:
            w = w_dict[row_seed].clone()
            w[col_styles] = w_dict[col_seed][col_styles]
            image = G.synthesis(w[np.newaxis], noise_mode=noise_mode)
            image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            image_dict[(row_seed, col_seed)] = image[0].cpu().numpy()

    os.makedirs(outdir, exist_ok=True)
    # print('Saving images...')
    # for (row_seed, col_seed), image in image_dict.items():
    #     PIL.Image.fromarray(image, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png')

    print('Saving image grid...')
    W = G.img_resolution // 2
    H = G.img_resolution
    canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black')
    for row_idx, row_seed in enumerate([0] + row_seeds):
        for col_idx, col_seed in enumerate([0] + col_seeds):
            if row_idx == 0 and col_idx == 0:
                continue
            key = (row_seed, col_seed)
            if row_idx == 0:
                key = (col_seed, col_seed)
            if col_idx == 0:
                key = (row_seed, row_seed)
            canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx))
    canvas.save(f'{outdir}/grid.png')


#----------------------------------------------------------------------------

if __name__ == "__main__":
    generate_style_mix()  # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
stylegan_human/stylemixing_video.py  (new file, mode 100755)

# Copyright (c) SenseTime Research. All rights reserved.

"""Here we demo style-mixing results using StyleGAN2 pretrained model.
   Script reference: https://github.com/PDillis/stylegan2-fun """

import argparse
import legacy
import scipy
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
from typing import List
import re
import sys
import os
import click
import torch

os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
import moviepy.editor

"""
Generate style mixing video.
Examples:

\b
python stylemixing_video.py --network=pretrained_models/stylegan_human_v2_1024.pkl --row-seed=3859 \\
    --col-seeds=3098,31759,3791 --col-styles=8-12 --trunc=0.8 --outdir=outputs/stylemixing_video
"""


@click.command()
@click.option('--network', 'network_pkl', help='Path to network pickle filename', required=True)
@click.option('--row-seed', 'src_seed', type=legacy.num_range, help='Random seed to use for image source row', required=True)
@click.option('--col-seeds', 'dst_seeds', type=legacy.num_range, help='Random seeds to use for image columns (style)', required=True)
@click.option('--col-styles', 'col_styles', type=legacy.num_range, help='Style layer range (default: %(default)s)', default='0-6')
@click.option('--only-stylemix', 'only_stylemix', help='Add flag to only show the style mxied images in the video', default=False)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi (default: %(default)s)', default=1)
@click.option('--duration-sec', 'duration_sec', type=float, help='Duration of video (default: %(default)s)', default=10)
@click.option('--fps', 'mp4_fps', type=int, help='FPS of generated video (default: %(default)s)', default=10)
@click.option('--indent-range', 'indent_range', type=int, default=30)
@click.option('--outdir', help='Root directory for run results (default: %(default)s)', default='outputs/stylemixing_video', metavar='DIR')
def style_mixing_video(network_pkl: str,
                       src_seed: List[int],        # Seed of the source image style (row)
                       dst_seeds: List[int],       # Seeds of the destination image styles (columns)
                       col_styles: List[int],      # Styles to transfer from first row to first column
                       truncation_psi=float,
                       only_stylemix=bool,         # True if user wishes to show only thre style transferred result
                       duration_sec=float,
                       smoothing_sec=1.0,
                       mp4_fps=int,
                       mp4_codec="libx264",
                       mp4_bitrate="16M",
                       minibatch_size=8,
                       noise_mode='const',
                       indent_range=int,
                       outdir=str):
    # Calculate the number of frames:
    print('col_seeds: ', dst_seeds)
    num_frames = int(np.rint(duration_sec * mp4_fps))

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
    dtype = torch.float32 if device.type == 'mps' else torch.float64
    with dnnlib.util.open_url(network_pkl) as f:
        Gs = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)

    print(Gs.num_ws, Gs.w_dim, Gs.img_resolution)
    max_style = int(2 * np.log2(Gs.img_resolution)) - 3
    assert max(col_styles) <= max_style, f"Maximum col-style allowed: {max_style}"

    # Left col latents
    print('Generating Source W vectors...')
    src_shape = [num_frames] + [Gs.z_dim]
    src_z = np.random.RandomState(*src_seed).randn(*src_shape).astype(np.float32)  # [frames, src, component]
    src_z = scipy.ndimage.gaussian_filter(src_z, [smoothing_sec * mp4_fps] + [0] * (2 - 1), mode="wrap")
    src_z /= np.sqrt(np.mean(np.square(src_z)))
    # Map into the detangled latent space W and do truncation trick
    src_w = Gs.mapping(torch.from_numpy(src_z).to(device, dtype=dtype), None)
    w_avg = Gs.mapping.w_avg
    src_w = w_avg + (src_w - w_avg) * truncation_psi

    # Top row latents (fixed reference)
    print('Generating Destination W vectors...')
    dst_z = np.stack([np.random.RandomState(seed).randn(Gs.z_dim) for seed in dst_seeds])
    dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device, dtype=dtype), None)
    dst_w = w_avg + (dst_w - w_avg) * truncation_psi

    # Get the width and height of each image:
    H = Gs.img_resolution        # 1024
    W = Gs.img_resolution // 2   # 512

    # Generate ALL the source images:
    src_images = Gs.synthesis(src_w, noise_mode=noise_mode)
    src_images = (src_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    # Generate the column images:
    dst_images = Gs.synthesis(dst_w, noise_mode=noise_mode)
    dst_images = (dst_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    print('Generating full video (including source and destination images)')
    # Generate our canvas where we will paste all the generated images:
    canvas = PIL.Image.new("RGB", ((W - indent_range) * (len(dst_seeds) + 1), H * (len(src_seed) + 1)), "white")  # W, H

    for col, dst_image in enumerate(list(dst_images)):  # dst_image: [3, 1024, 512]
        canvas.paste(PIL.Image.fromarray(dst_image.cpu().numpy(), "RGB"), ((col + 1) * (W - indent_range), 0))  # H

    # Aux functions: Frame generation func for moviepy.
    def make_frame(t):
        # Get the frame number according to time t:
        frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
        # We wish the image belonging to the frame at time t:
        src_image = src_images[frame_idx]  # always in the same place
        canvas.paste(PIL.Image.fromarray(src_image.cpu().numpy(), "RGB"), (0 - indent_range, H))  # Paste it to the lower left

        # Now, for each of the column images:
        for col, dst_image in enumerate(list(dst_images)):
            # Select the pertinent latent w column:
            w_col = np.stack([dst_w[col].cpu()])  # [18, 512] -> [1, 18, 512]
            w_col = torch.from_numpy(w_col).to(device, dtype=dtype)
            # Replace the values defined by col_styles:
            w_col[:, col_styles] = src_w[frame_idx, col_styles]  # .cpu()
            # Generate these synthesized images:
            col_images = Gs.synthesis(w_col, noise_mode=noise_mode)
            col_images = (col_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            # Paste them in their respective spot:
            for row, image in enumerate(list(col_images)):
                canvas.paste(PIL.Image.fromarray(image.cpu().numpy(), "RGB"),
                             ((col + 1) * (W - indent_range), (row + 1) * H))

        return np.array(canvas)

    # Generate video using make_frame:
    print('Generating style-mixed video...')
    videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
    grid_size = [len(dst_seeds), len(src_seed)]
    mp4 = "{}x{}-style-mixing_{}_{}.mp4".format(*grid_size, min(col_styles), max(col_styles))
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    videoclip.write_videofile(os.path.join(outdir, mp4), fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)


if __name__ == "__main__":
    style_mixing_video()
stylegan_human/torch_utils/__init__.py  (new file, mode 100644)

# Copyright (c) SenseTime Research. All rights reserved.

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# empty
stylegan_human/torch_utils/custom_ops.py
0 → 100644
View file @
fba8bde8
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import os
import glob
import torch
import torch.utils.cpp_extension
import importlib
import hashlib
import shutil
from pathlib import Path
import re
import uuid

from torch.utils.file_baton import FileBaton

#----------------------------------------------------------------------------
# Global options.

verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'

#----------------------------------------------------------------------------
# Internal helper funcs.

def _find_compiler_bindir():
    patterns = [
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if len(matches):
            return matches[-1]
    return None

def _get_mangled_gpu_name():
    name = torch.cuda.get_device_name().lower()
    out = []
    for c in name:
        if re.match('[a-z0-9_-]+', c):
            out.append(c)
        else:
            out.append('-')
    return ''.join(out)

#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

_cached_plugins = dict()

def get_plugin(module_name, sources, **build_kwargs):
    assert verbosity in ['none', 'brief', 'full']

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Compile and load.
        verbose_build = (verbosity == 'full')

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        source_dirs_set = set(os.path.dirname(source) for source in sources)
        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
            all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))

            # Compute a combined hash digest for all source files in the same
            # custom op directory (usually .cu, .cpp, .py and .h files).
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())
            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                os.makedirs(digest_build_dir, exist_ok=True)
                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
                if baton.try_acquire():
                    try:
                        for src in all_source_files:
                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
                    finally:
                        baton.release()
                else:
                    # Someone else is copying source files under the digest dir,
                    # wait until done and continue.
                    baton.wait()
            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                verbose=verbose_build, sources=digest_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        module = importlib.import_module(module_name)

    except:
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module

#----------------------------------------------------------------------------

def get_plugin_v3(module_name, sources, headers=None, source_dir=None, **build_kwargs):
    assert verbosity in ['none', 'brief', 'full']
    if headers is None:
        headers = []
    if source_dir is not None:
        sources = [os.path.join(source_dir, fname) for fname in sources]
        headers = [os.path.join(source_dir, fname) for fname in headers]

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
    verbose_build = (verbosity == 'full')

    # Compile and load.
    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
        # break the build or unnecessarily restrict what's available to nvcc.
        # Unset it to let nvcc decide based on what's available on the
        # machine.
        os.environ['TORCH_CUDA_ARCH_LIST'] = ''

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        #
        # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
        # around the *.cu dependency bug in ninja config.
        #
        all_source_files = sorted(sources + headers)
        all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
        if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):

            # Compute combined hash digest for all source files.
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())

            # Select cached build directory name.
            source_digest = hash_md5.hexdigest()
            build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')

            if not os.path.isdir(cached_build_dir):
                tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
                os.makedirs(tmpdir)
                for src in all_source_files:
                    shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
                try:
                    os.replace(tmpdir, cached_build_dir) # atomic
                except OSError:
                    # source directory already exists, delete tmpdir and its contents.
                    shutil.rmtree(tmpdir)
                    if not os.path.isdir(cached_build_dir):
                        raise

            # Compile.
            cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
                verbose=verbose_build, sources=cached_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)

        # Load.
        module = importlib.import_module(module_name)

    except:
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache dict.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module
\ No newline at end of file
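A minimal usage sketch of how a caller might compile and cache a CUDA op through get_plugin. This sketch is not part of the commit: the module name, source paths, and build flag below are illustrative assumptions, and a working CUDA toolchain is assumed to be present.

# Hypothetical usage sketch; all names and paths here are assumptions for illustration.
import os
from stylegan_human.torch_utils import custom_ops

def load_example_plugin(src_dir):
    # Compiles the extension on first call; later calls return the cached module.
    return custom_ops.get_plugin(
        module_name='example_upfirdn2d_plugin',                 # assumed module name
        sources=[os.path.join(src_dir, 'upfirdn2d.cpp'),        # assumed source files
                 os.path.join(src_dir, 'upfirdn2d_kernel.cu')],
        extra_cuda_cflags=['--use_fast_math'],                  # forwarded to torch.utils.cpp_extension.load
    )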
stylegan_human/torch_utils/misc.py
0 → 100644
View file @
fba8bde8
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import re
import contextlib
import numpy as np
import torch
import warnings
import dnnlib

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.

try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().

class suppress_tracer_warnings(warnings.catch_warnings):
    def __enter__(self):
        super().__enter__()
        warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
        return self

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    decorator.__name__ = fn.__name__
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
            if window >= 2:
                j = (i - rnd.randint(window)) % order.size
                order[i], order[j] = order[j], order[i]
            idx += 1

#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.

def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

def copy_params_and_buffers(src_module, dst_module, require_all=False):
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield

#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.

def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs

#----------------------------------------------------------------------------
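A brief sketch of how the helpers above are typically exercised. This is not part of the commit; the toy tensor, toy dataset, and the assumption that dnnlib is importable alongside torch_utils are all illustrative.

# Illustrative sketch (assumptions: dnnlib is on the import path; shapes are arbitrary).
import torch
from stylegan_human.torch_utils import misc

x = torch.randn(8, 512)
misc.assert_shape(x, [None, 512])       # batch size may vary, feature dim must be 512

dataset = torch.arange(100)             # anything with __len__ works as a stand-in dataset
sampler = misc.InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
it = iter(sampler)
first_ten = [next(it) for _ in range(10)]   # indices loop forever, shuffled within a window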
stylegan_human/torch_utils/models.py
0 → 100644
View file @
fba8bde8
# Copyright (c) SenseTime Research. All rights reserved.
# https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
import math
import random
import functools
import operator

import torch
from torch import nn
from torch.nn import functional as F
import torch.nn.init as init
from torch.autograd import Function

from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d


class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    k /= k.sum()
    return k


class Upsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
        return out


class Downsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()
        kernel = make_kernel(kernel)
        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)
        self.register_buffer("kernel", kernel)
        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)
        return out


class EqualConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
        self.stride = stride
        self.padding = padding
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )
        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None
        self.activation = activation
        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
        return out

    def __repr__(self):
        return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"


class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        out = F.leaky_relu(input, negative_slope=self.negative_slope)
        return out * math.sqrt(2)


class ModulatedConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()
        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1
            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2
        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )
        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})"
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)
        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out


class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()
        return image + self.weight * noise


class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, channel, size, size // 2))

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)
        return out


class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()
        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )
        self.noise = NoiseInjection()
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise)
        out = self.activate(out)
        return out


class ToRGB(nn.Module):
    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()
        if upsample:
            self.upsample = Upsample(blur_kernel)
        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias
        if skip is not None:
            skip = self.upsample(skip)
            out = out + skip
        return out


class Generator(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=1,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
        small=False,
        small_isaac=False,
    ):
        super().__init__()
        self.size = size

        if small and size > 64:
            raise ValueError("small only works for sizes <= 64")

        self.style_dim = style_dim
        layers = [PixelNorm()]
        for i in range(n_mlp):
            layers.append(
                EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu")
            )
        self.style = nn.Sequential(*layers)

        if small:
            self.channels = {
                4: 64 * channel_multiplier,
                8: 64 * channel_multiplier,
                16: 64 * channel_multiplier,
                32: 64 * channel_multiplier,
                64: 64 * channel_multiplier,
            }
        elif small_isaac:
            self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128}
        else:
            self.channels = {
                4: 512,
                8: 512,
                16: 512,
                32: 512,
                64: 256 * channel_multiplier,
                128: 128 * channel_multiplier,
                256: 64 * channel_multiplier,
                512: 32 * channel_multiplier,
                1024: 16 * channel_multiplier,
            }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res // 2]
            self.noises.register_buffer(
                "noise_{}".format(layer_idx), torch.randn(*shape)
            )

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]
            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )
            self.convs.append(
                StyledConv(out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel)
            )
            self.to_rgbs.append(ToRGB(out_channel, style_dim))
            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        device = self.input.input.device
        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2 // 2, device=device)]
        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i // 2, device=device))
        return noises

    def mean_latent(self, n_latent):
        latent_in = torch.randn(n_latent, self.style_dim, device=self.input.input.device)
        latent = self.style(latent_in).mean(0, keepdim=True)
        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        return_features=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
        real=False,
    ):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, "noise_{}".format(i))
                    for i in range(self.num_layers)
                ]

        if truncation < 1:
            # print('truncation_latent: ', truncation_latent.shape)
            if not real:  # if type(styles) == list:
                style_t = []
                for style in styles:
                    style_t.append(
                        truncation_latent + truncation * (style - truncation_latent)
                    )  # (-1.1162e-03-(-1.0914e-01))*0.8+(-1.0914e-01)
                styles = style_t
            else:  # styles are latent (tensor: 1,18,512), for real PTI output
                truncation_latent = truncation_latent.repeat(18, 1).unsqueeze(0)  # (1,512) --> (1,18,512)
                styles = torch.add(truncation_latent, torch.mul(torch.sub(styles, truncation_latent), truncation))
                # print('now styles after truncation : ', styles)

        # if type(styles) == list and len(styles) < 2:  # this if for input as list of [(1,512)]
        if not real:
            if len(styles) < 2:
                inject_index = self.n_latent
                if styles[0].ndim < 3:
                    latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
                else:
                    latent = styles[0]
            elif type(styles) == list:
                if inject_index is None:
                    inject_index = 4
                latent = styles[0].unsqueeze(0)
                if latent.shape[1] == 1:
                    latent = latent.repeat(1, inject_index, 1)
                else:
                    latent = latent[:, :inject_index, :]
                latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
                latent = torch.cat([latent, latent2], 1)
        else:  # input is tensor of size with torch.Size([1, 18, 512]), for real PTI output
            latent = styles

        # print(f'processed latent: {latent.shape}')

        features = {}
        out = self.input(latent)
        features["out_0"] = out
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        features["conv1_0"] = out

        skip = self.to_rgb1(out, latent[:, 1])
        features["skip_0"] = skip
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            features["conv1_{}".format(i)] = out
            out = conv2(out, latent[:, i + 1], noise=noise2)
            features["conv2_{}".format(i)] = out
            skip = to_rgb(out, latent[:, i + 2], skip)
            features["skip_{}".format(i)] = skip
            i += 2

        image = skip

        if return_latents:
            return image, latent
        elif return_features:
            return image, features
        else:
            return image, None


class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
            stride = 2
            self.padding = 0
        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()
        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)
        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)
        return out


class StyleDiscriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False):
        super().__init__()

        if small:
            channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64}
        else:
            channels = {
                4: 512,
                8: 512,
                16: 512,
                32: 512,
                64: 256 * channel_multiplier,
                128: 128 * channel_multiplier,
                256: 64 * channel_multiplier,
                512: 32 * channel_multiplier,
                1024: 16 * channel_multiplier,
            }

        convs = [ConvLayer(3, channels[size], 1)]
        log_size = int(math.log(size, 2))
        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            convs.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        h = input
        h_list = []

        for index, blocklist in enumerate(self.convs):
            h = blocklist(h)
            h_list.append(h)

        out = h
        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)
        h_list.append(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out, h_list


class StyleEncoder(nn.Module):
    def __init__(self, size, w_dim=512):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16,
        }

        self.w_dim = w_dim
        log_size = int(math.log(size, 2))

        convs = [ConvLayer(3, channels[size], 1)]

        in_channel = channels[size]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            convs.append(ResBlock(in_channel, out_channel))
            in_channel = out_channel

        convs.append(EqualConv2d(in_channel, 2 * self.w_dim, 4, padding=0, bias=False))

        self.convs = nn.Sequential(*convs)

    def forward(self, input):
        out = self.convs(input)
        # return out.view(len(input), self.n_latents, self.w_dim)
        reshaped = out.view(len(input), 2 * self.w_dim)
        return reshaped[:, : self.w_dim], reshaped[:, self.w_dim :]


def kaiming_init(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)


def normal_init(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.normal_(m.weight, 0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)
\ No newline at end of file
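A hedged usage sketch for the rectangular-output Generator defined above. Not part of the commit: the 1024 resolution, 512-d style vector, 8-layer mapping network, and channel multiplier are assumed settings, and the sketch assumes the op_edit CUDA extensions build successfully on the host.

# Illustrative sketch (assumptions: 1024x512 output, 512-d style, n_mlp=8, CUDA ops available).
import torch
from stylegan_human.torch_utils.models import Generator

g = Generator(size=1024, style_dim=512, n_mlp=8, channel_multiplier=2)
z = torch.randn(1, 512)                          # one latent code in Z space
img, _ = g([z], truncation=1, randomize_noise=True)
print(img.shape)                                 # expected: torch.Size([1, 3, 1024, 512])

Note the ConstantInput starts from a 4x2 feature map and every noise buffer is 2**res by 2**res // 2, which is what gives this generator its 2:1 full-body aspect ratio.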
stylegan_human/torch_utils/models_face.py
0 → 100644
View file @
fba8bde8
# Copyright (c) SenseTime Research. All rights reserved.
import math
import random
import functools
import operator

import torch
from torch import nn
from torch.nn import functional as F
import torch.nn.init as init
from torch.autograd import Function

from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d


class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    k /= k.sum()
    return k


class Upsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
        return out


class Downsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()
        kernel = make_kernel(kernel)
        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)
        self.register_buffer("kernel", kernel)
        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)
        return out


class EqualConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
        self.stride = stride
        self.padding = padding
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )
        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None
        self.activation = activation
        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
        return out

    def __repr__(self):
        return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"


class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        out = F.leaky_relu(input, negative_slope=self.negative_slope)
        return out * math.sqrt(2)


class ModulatedConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()
        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1
            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2
        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )
        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})"
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)
        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out


class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()
        return image + self.weight * noise


class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)
        return out


class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()
        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )
        self.noise = NoiseInjection()
        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
        # self.activate = ScaledLeakyReLU(0.2)
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise)
        # out = out + self.bias
        out = self.activate(out)
        return out


class ToRGB(nn.Module):
    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()
        if upsample:
            self.upsample = Upsample(blur_kernel)
        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias
        if skip is not None:
            skip = self.upsample(skip)
            out = out + skip
        return out


class Generator(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=1,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
        small=False,
        small_isaac=False,
    ):
        super().__init__()
        self.size = size

        if small and size > 64:
            raise ValueError("small only works for sizes <= 64")

        self.style_dim = style_dim
        layers = [PixelNorm()]
        for i in range(n_mlp):
            layers.append(
                EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu")
            )
        self.style = nn.Sequential(*layers)

        if small:
            self.channels = {
                4: 64 * channel_multiplier,
                8: 64 * channel_multiplier,
                16: 64 * channel_multiplier,
                32: 64 * channel_multiplier,
                64: 64 * channel_multiplier,
            }
        elif small_isaac:
            self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128}
        else:
            self.channels = {
                4: 512,
                8: 512,
                16: 512,
                32: 512,
                64: 256 * channel_multiplier,
                128: 128 * channel_multiplier,
                256: 64 * channel_multiplier,
                512: 32 * channel_multiplier,
                1024: 16 * channel_multiplier,
            }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(
                "noise_{}".format(layer_idx), torch.randn(*shape)
            )

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]
            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )
            self.convs.append(
                StyledConv(out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel)
            )
            self.to_rgbs.append(ToRGB(out_channel, style_dim))
            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        device = self.input.input.device
        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
        return noises

    def mean_latent(self, n_latent):
        latent_in = torch.randn(n_latent, self.style_dim, device=self.input.input.device)
        latent = self.style(latent_in).mean(0, keepdim=True)
        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        return_features=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        if not input_is_latent:
            # print("haha")
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, "noise_{}".format(i))
                    for i in range(self.num_layers)
                ]

        if truncation < 1:
            style_t = []
            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )
            styles = style_t

        # print(styles)
        if len(styles) < 2:
            inject_index = self.n_latent
            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
                # print("a")
            else:
                # print(len(styles))
                latent = styles[0]
                # print("b", latent.shape)
        else:
            # print("c")
            if inject_index is None:
                inject_index = 4
            latent = styles[0].unsqueeze(0)
            if latent.shape[1] == 1:
                latent = latent.repeat(1, inject_index, 1)
            else:
                latent = latent[:, :inject_index, :]
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
            latent = torch.cat([latent, latent2], 1)

        features = {}
        out = self.input(latent)
        features["out_0"] = out
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        features["conv1_0"] = out

        skip = self.to_rgb1(out, latent[:, 1])
        features["skip_0"] = skip
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            features["conv1_{}".format(i)] = out
            out = conv2(out, latent[:, i + 1], noise=noise2)
            features["conv2_{}".format(i)] = out
            skip = to_rgb(out, latent[:, i + 2], skip)
            features["skip_{}".format(i)] = skip
            i += 2

        image = skip

        if return_latents:
            return image, latent
        elif return_features:
            return image, features
        else:
            return image, None


class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
            stride = 2
            self.padding = 0
        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()
        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)
        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)
        return out


class StyleDiscriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False):
        super().__init__()

        if small:
            channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64}
        else:
            channels = {
                4: 512,
                8: 512,
                16: 512,
                32: 512,
                64: 256 * channel_multiplier,
                128: 128 * channel_multiplier,
                256: 64 * channel_multiplier,
                512: 32 * channel_multiplier,
                1024: 16 * channel_multiplier,
            }

        convs = [ConvLayer(3, channels[size], 1)]
        log_size = int(math.log(size, 2))
        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            convs.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    # def forward(self, input):
    #     out = self.convs(input)
    #     batch, channel, height, width = out.shape
    #     group = min(batch, self.stddev_group)
    #     stddev = out.view(
    #         group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
    #     )
    #     stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
    #     stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
    #     stddev = stddev.repeat(group, 1, height, width)
    #     out = torch.cat([out, stddev], 1)
    #     out = self.final_conv(out)
    #     out = out.view(batch, -1)
    #     out = self.final_linear(out)
    #     return out

    def forward(self, input):
        h = input
        h_list = []

        for index, blocklist in enumerate(self.convs):
            h = blocklist(h)
            h_list.append(h)

        out = h
        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)
        h_list.append(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out, h_list


class StyleEncoder(nn.Module):
    def __init__(self, size, w_dim=512):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16,
        }

        self.w_dim = w_dim
        log_size = int(math.log(size, 2))
        # self.n_latents = log_size*2 - 2

        convs = [ConvLayer(3, channels[size], 1)]

        in_channel = channels[size]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            convs.append(ResBlock(in_channel, out_channel))
            in_channel = out_channel

        # convs.append(EqualConv2d(in_channel, self.n_latents*self.w_dim, 4, padding=0, bias=False))
        convs.append(EqualConv2d(in_channel, 2 * self.w_dim, 4, padding=0, bias=False))

        self.convs = nn.Sequential(*convs)

    def forward(self, input):
        out = self.convs(input)
        # return out.view(len(input), self.n_latents, self.w_dim)
        reshaped = out.view(len(input), 2 * self.w_dim)
        return reshaped[:, : self.w_dim], reshaped[:, self.w_dim :]


def kaiming_init(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)


def normal_init(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.normal_(m.weight, 0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)
\ No newline at end of file
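models_face.py mirrors models.py but keeps the square 4x4 constant input and square noise buffers, and its Generator.forward drops the PTI-specific `real` branch, so it targets square face checkpoints rather than the 2:1 full-body model. A hedged sketch of the shape difference (not part of the commit; the 256 resolution and other settings are assumptions, and the op_edit CUDA ops must build):

# Illustrative sketch (assumptions: 256x256 face generator, 512-d style, n_mlp=8).
import torch
from stylegan_human.torch_utils.models_face import Generator as FaceGenerator

g = FaceGenerator(size=256, style_dim=512, n_mlp=8, channel_multiplier=2)
z = torch.randn(1, 512)
img, _ = g([z])
print(img.shape)   # expected: torch.Size([1, 3, 256, 256]) -- square, unlike models.py's H x H/2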