chenpangpang / NVComposer · Commits · 30af93f2

Commit 30af93f2 authored Dec 26, 2024 by chenpangpang

    feat: initial GPU commit

Parent: 68e98ab8 · Pipeline #2159 canceled with stages · Changes: 66 · Pipelines: 1

Showing 20 changed files with 6812 additions and 0 deletions (+6812 −0)
NVComposer/core/data/__init__.py (+0 −0)
NVComposer/core/data/camera_pose_utils.py (+277 −0)
NVComposer/core/data/combined_multi_view_dataset.py (+341 −0)
NVComposer/core/data/utils.py (+184 −0)
NVComposer/core/distributions.py (+102 −0)
NVComposer/core/ema.py (+84 −0)
NVComposer/core/losses/__init__.py (+1 −0)
NVComposer/core/losses/contperceptual.py (+173 −0)
NVComposer/core/losses/vqperceptual.py (+217 −0)
NVComposer/core/models/autoencoder.py (+395 −0)
NVComposer/core/models/diffusion.py (+1679 −0)
NVComposer/core/models/samplers/__init__.py (+0 −0)
NVComposer/core/models/samplers/ddim.py (+546 −0)
NVComposer/core/models/samplers/dpm_solver/__init__.py (+1 −0)
NVComposer/core/models/samplers/dpm_solver/dpm_solver.py (+1298 −0)
NVComposer/core/models/samplers/dpm_solver/sampler.py (+91 −0)
NVComposer/core/models/samplers/plms.py (+358 −0)
NVComposer/core/models/samplers/uni_pc/__init__.py (+0 −0)
NVComposer/core/models/samplers/uni_pc/sampler.py (+67 −0)
NVComposer/core/models/samplers/uni_pc/uni_pc.py (+998 −0)
NVComposer/core/data/__init__.py (new file, mode 100755, empty)
NVComposer/core/data/camera_pose_utils.py (new file, mode 100755, +277 −0)

```python
import copy

import numpy as np
import torch
from scipy.spatial.transform import Rotation as R


def get_opencv_from_blender(matrix_world, fov, image_size):
    # convert matrix_world to opencv format extrinsics
    opencv_world_to_cam = matrix_world.inverse()
    opencv_world_to_cam[1, :] *= -1
    opencv_world_to_cam[2, :] *= -1
    R, T = opencv_world_to_cam[:3, :3], opencv_world_to_cam[:3, 3]
    R, T = R.unsqueeze(0), T.unsqueeze(0)
    # convert fov to opencv format intrinsics
    focal = 1 / np.tan(fov / 2)
    intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
    opencv_cam_matrix = torch.from_numpy(intrinsics).unsqueeze(0).float()
    opencv_cam_matrix[:, :2, -1] += torch.tensor([image_size / 2, image_size / 2])
    opencv_cam_matrix[:, [0, 1], [0, 1]] *= image_size / 2
    return R, T, opencv_cam_matrix


def cartesian_to_spherical(xyz):
    xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
    z = np.sqrt(xy + xyz[:, 2] ** 2)
    # for elevation angle defined from z-axis down
    theta = np.arctan2(np.sqrt(xy), xyz[:, 2])
    azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
    return np.stack([theta, azimuth, z], axis=-1)


def spherical_to_cartesian(spherical_coords):
    # convert from spherical to cartesian coordinates
    theta, azimuth, radius = spherical_coords.T
    x = radius * np.sin(theta) * np.cos(azimuth)
    y = radius * np.sin(theta) * np.sin(azimuth)
    z = radius * np.cos(theta)
    return np.stack([x, y, z], axis=-1)


def look_at(eye, center, up):
    # Create a normalized direction vector from eye to center
    f = np.array(center) - np.array(eye)
    f /= np.linalg.norm(f)
    # Create a normalized right vector
    up_norm = np.array(up) / np.linalg.norm(up)
    s = np.cross(f, up_norm)
    s /= np.linalg.norm(s)
    # Recompute the up vector
    u = np.cross(s, f)
    # Create rotation matrix R
    R = np.array(
        [[s[0], s[1], s[2]], [u[0], u[1], u[2]], [-f[0], -f[1], -f[2]]]
    )
    # Create translation vector T
    T = -np.dot(R, np.array(eye))
    return R, T


def get_blender_from_spherical(elevation, azimuth):
    """Generates blender camera from spherical coordinates."""
    cartesian_coords = spherical_to_cartesian(np.array([[elevation, azimuth, 3.5]]))
    # get camera rotation
    center = np.array([0, 0, 0])
    eye = cartesian_coords[0]
    up = np.array([0, 0, 1])
    R, T = look_at(eye, center, up)
    R = R.T
    T = -np.dot(R, T)
    RT = np.concatenate([R, T.reshape(3, 1)], axis=-1)
    blender_cam = torch.from_numpy(RT).float()
    blender_cam = torch.cat([blender_cam, torch.tensor([[0, 0, 0, 1]])], dim=0)
    print(blender_cam)
    return blender_cam


def invert_pose(r, t):
    r_inv = r.T
    t_inv = -np.dot(r_inv, t)
    return r_inv, t_inv


def transform_pose_sequence_to_relative(poses, as_z_up=False):
    """
    poses: a sequence of 3*4 C2W camera pose matrices
    as_z_up: output in z-up format. If False, the output is in y-up format
    """
    r0, t0 = poses[0][:3, :3], poses[0][:3, 3]
    # r0_inv, t0_inv = invert_pose(r0, t0)
    r0_inv = r0.T
    new_rt0 = np.hstack([np.eye(3, 3), np.zeros((3, 1))])
    if as_z_up:
        new_rt0 = c2w_y_up_to_z_up(new_rt0)
    transformed_poses = [new_rt0]
    for pose in poses[1:]:
        r, t = pose[:3, :3], pose[:3, 3]
        new_r = np.dot(r0_inv, r)
        new_t = np.dot(r0_inv, t - t0)
        new_rt = np.hstack([new_r, new_t[:, None]])
        if as_z_up:
            new_rt = c2w_y_up_to_z_up(new_rt)
        transformed_poses.append(new_rt)
    return transformed_poses


def c2w_y_up_to_z_up(c2w_3x4):
    R_y_up_to_z_up = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
    R = c2w_3x4[:, :3]
    t = c2w_3x4[:, 3]
    R_z_up = R_y_up_to_z_up @ R
    t_z_up = R_y_up_to_z_up @ t
    T_z_up = np.hstack((R_z_up, t_z_up.reshape(3, 1)))
    return T_z_up


def transform_pose_sequence_to_relative_w2c(poses):
    new_rt_list = []
    first_frame_rt = copy.deepcopy(poses[0])
    first_frame_r_inv = first_frame_rt[:, :3].T
    first_frame_t = first_frame_rt[:, -1]
    for rt in poses:
        rt[:, :3] = np.matmul(rt[:, :3], first_frame_r_inv)
        rt[:, -1] = rt[:, -1] - np.matmul(rt[:, :3], first_frame_t)
        new_rt_list.append(copy.deepcopy(rt))
    return new_rt_list


def transform_pose_sequence_to_relative_c2w(poses):
    first_frame_rt = poses[0]
    first_frame_r_inv = first_frame_rt[:, :3].T
    first_frame_t = first_frame_rt[:, -1]
    rotations = poses[:, :, :3]
    translations = poses[:, :, 3]
    # Compute new rotations and translations in batch
    new_rotations = torch.matmul(first_frame_r_inv, rotations)
    new_translations = torch.matmul(
        first_frame_r_inv, (translations - first_frame_t.unsqueeze(0)).unsqueeze(-1)
    )
    # Concatenate new rotations and translations
    new_rt = torch.cat([new_rotations, new_translations], dim=-1)
    return new_rt


def convert_w2c_between_c2w(poses):
    rotations = poses[:, :, :3]
    translations = poses[:, :, 3]
    new_rotations = rotations.transpose(-1, -2)
    new_translations = torch.matmul(-new_rotations, translations.unsqueeze(-1))
    new_rt = torch.cat([new_rotations, new_translations], dim=-1)
    return new_rt


def slerp(q1, q2, t):
    """
    Performs spherical linear interpolation (SLERP) between two quaternions.

    Args:
        q1 (torch.Tensor): Start quaternion (4,).
        q2 (torch.Tensor): End quaternion (4,).
        t (float or torch.Tensor): Interpolation parameter in [0, 1].

    Returns:
        torch.Tensor: Interpolated quaternion (4,).
    """
    q1 = q1 / torch.linalg.norm(q1)  # Normalize q1
    q2 = q2 / torch.linalg.norm(q2)  # Normalize q2
    dot = torch.dot(q1, q2)
    # Ensure shortest path (flip q2 if needed)
    if dot < 0.0:
        q2 = -q2
        dot = -dot
    # Avoid numerical precision issues
    dot = torch.clamp(dot, -1.0, 1.0)
    theta = torch.acos(dot)  # Angle between q1 and q2
    if theta < 1e-6:
        # If very close, use linear interpolation
        return (1 - t) * q1 + t * q2
    sin_theta = torch.sin(theta)
    return (torch.sin((1 - t) * theta) / sin_theta) * q1 + (
        torch.sin(t * theta) / sin_theta
    ) * q2


def interpolate_camera_poses(c2w: torch.Tensor, factor: int) -> torch.Tensor:
    """
    Interpolates a sequence of camera c2w poses to N times the length of the original sequence.

    Args:
        c2w (torch.Tensor): Input camera poses of shape (N, 3, 4).
        factor (int): The upsampling factor (e.g., 2 for doubling the length).

    Returns:
        torch.Tensor: Interpolated camera poses of shape (N * factor, 3, 4).
    """
    assert c2w.ndim == 3 and c2w.shape[1:] == (
        3,
        4,
    ), "Input tensor must have shape (N, 3, 4)."
    assert factor > 1, "Upsampling factor must be greater than 1."
    N = c2w.shape[0]
    new_length = N * factor
    # Extract rotations (R) and translations (T)
    rotations = c2w[:, :3, :3]  # Shape (N, 3, 3)
    translations = c2w[:, :3, 3]  # Shape (N, 3)
    # Convert rotations to quaternions for interpolation
    quaternions = torch.tensor(
        R.from_matrix(rotations.numpy()).as_quat()
    )  # Shape (N, 4)
    # Initialize interpolated quaternions and translations
    interpolated_quats = []
    interpolated_translations = []
    # Perform interpolation
    for i in range(N - 1):
        # Start and end quaternions and translations for this segment
        q1, q2 = quaternions[i], quaternions[i + 1]
        t1, t2 = translations[i], translations[i + 1]
        # Time steps for interpolation within this segment
        t_values = torch.linspace(0, 1, factor, dtype=torch.float32)
        # Interpolate quaternions using SLERP
        for t in t_values:
            interpolated_quats.append(slerp(q1, q2, t))
        # Interpolate translations linearly
        interp_t = t1 * (1 - t_values[:, None]) + t2 * t_values[:, None]
        interpolated_translations.append(interp_t)
    interpolated_quats.append(quaternions[0])
    interpolated_translations.append(translations[0].unsqueeze(0))
    # Add the last pose (end of sequence)
    interpolated_quats.append(quaternions[-1])
    interpolated_translations.append(translations[-1].unsqueeze(0))  # Add as 2D tensor
    # Combine interpolated results
    interpolated_quats = torch.stack(interpolated_quats, dim=0)  # Shape (new_length, 4)
    interpolated_translations = torch.cat(
        interpolated_translations, dim=0
    )  # Shape (new_length, 3)
    # Convert quaternions back to rotation matrices
    interpolated_rotations = torch.tensor(
        R.from_quat(interpolated_quats.numpy()).as_matrix()
    )  # Shape (new_length, 3, 3)
    # Form final c2w matrix
    interpolated_c2w = torch.zeros((new_length, 3, 4), dtype=torch.float32)
    interpolated_c2w[:, :3, :3] = interpolated_rotations
    interpolated_c2w[:, :3, 3] = interpolated_translations
    return interpolated_c2w
```
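Reviewer sanity check (not part of the commit): the sketch below densifies a two-pose trajectory with `interpolate_camera_poses`. Note that the post-loop bookkeeping appends two extra poses, so the output length only matches `new_length = N * factor` when `factor == 2`; the sketch uses that value, and the poses themselves are made up.

```python
import torch
from core.data.camera_pose_utils import interpolate_camera_poses  # in-repo path

# Two c2w poses: identity rotation, second camera shifted one unit along +x.
c2w = torch.zeros(2, 3, 4)
c2w[:, :3, :3] = torch.eye(3)
c2w[1, 0, 3] = 1.0

dense = interpolate_camera_poses(c2w, factor=2)
print(dense.shape)  # torch.Size([4, 3, 4]); rotations SLERPed, translations lerped
```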
NVComposer/core/data/combined_multi_view_dataset.py (new file, mode 100755, +341 −0)

```python
import PIL
import numpy as np
import torch
from PIL import Image

from .camera_pose_utils import (
    convert_w2c_between_c2w,
    transform_pose_sequence_to_relative_c2w,
)


def get_ray_embeddings(poses, size_h=256, size_w=256, fov_xy_list=None, focal_xy_list=None):
    """
    poses: sequence of cameras poses (y-up format)
    """
    use_focal = False
    if fov_xy_list is None or fov_xy_list[0] is None or fov_xy_list[0][0] is None:
        assert focal_xy_list is not None
        use_focal = True
    rays_embeddings = []
    for i in range(poses.shape[0]):
        cur_pose = poses[i]
        if use_focal:
            rays_o, rays_d = get_rays(  # [h, w, 3]
                cur_pose,
                size_h,
                size_w,
                focal_xy=focal_xy_list[i],
            )
        else:
            rays_o, rays_d = get_rays(
                cur_pose, size_h, size_w, fov_xy=fov_xy_list[i]
            )  # [h, w, 3]
        rays_plucker = torch.cat(
            [torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1
        )  # [h, w, 6]
        rays_embeddings.append(rays_plucker)
    rays_embeddings = (
        torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous()
    )  # [V, 6, h, w]
    return rays_embeddings


def get_rays(pose, h, w, fov_xy=None, focal_xy=None, opengl=True):
    x, y = torch.meshgrid(
        torch.arange(w, device=pose.device),
        torch.arange(h, device=pose.device),
        indexing="xy",
    )
    x = x.flatten()
    y = y.flatten()
    cx = w * 0.5
    cy = h * 0.5
    # print("fov_xy=", fov_xy)
    # print("focal_xy=", focal_xy)
    if focal_xy is None:
        assert fov_xy is not None, "fov_x/y and focal_x/y cannot both be None."
        focal_x = w * 0.5 / np.tan(0.5 * np.deg2rad(fov_xy[0]))
        focal_y = h * 0.5 / np.tan(0.5 * np.deg2rad(fov_xy[1]))
    else:
        assert (
            len(focal_xy) == 2
        ), "focal_xy should be a list-like object containing only two elements (focal length in x and y direction)."
        focal_x = w * focal_xy[0]
        focal_y = h * focal_xy[1]
    camera_dirs = torch.nn.functional.pad(
        torch.stack(
            [
                (x - cx + 0.5) / focal_x,
                (y - cy + 0.5) / focal_y * (-1.0 if opengl else 1.0),
            ],
            dim=-1,
        ),
        (0, 1),
        value=(-1.0 if opengl else 1.0),
    )  # [hw, 3]
    rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)  # [hw, 3]
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)  # [hw, 3]
    rays_o = rays_o.view(h, w, 3)
    rays_d = safe_normalize(rays_d).view(h, w, 3)
    return rays_o, rays_d


def safe_normalize(x, eps=1e-20):
    return x / length(x, eps)


def length(x, eps=1e-20):
    if isinstance(x, np.ndarray):
        return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
    else:
        return torch.sqrt(torch.clamp(dot(x, x), min=eps))


def dot(x, y):
    if isinstance(x, np.ndarray):
        return np.sum(x * y, -1, keepdims=True)
    else:
        return torch.sum(x * y, -1, keepdim=True)


def extend_list_by_repeating(original_list, target_length, repeat_idx, at_front):
    if not original_list:
        raise ValueError("The original list cannot be empty.")
    extended_list = []
    original_length = len(original_list)
    for i in range(target_length - original_length):
        extended_list.append(original_list[repeat_idx])
    if at_front:
        extended_list.extend(original_list)
        return extended_list
    else:
        original_list.extend(extended_list)
        return original_list


def select_evenly_spaced_elements(arr, x):
    if x <= 0 or len(arr) == 0:
        return []
    # Calculate step size as the ratio of length of the list and x
    step = len(arr) / x
    # Pick elements at indices that are multiples of step (round them to nearest integer)
    selected_elements = [arr[round(i * step)] for i in range(x)]
    return selected_elements


def convert_co3d_annotation_to_opengl_pose_and_intrinsics(frame_annotation):
    p = frame_annotation.viewpoint.principal_point
    f = frame_annotation.viewpoint.focal_length
    h, w = frame_annotation.image.size
    K = np.eye(3)
    s = (min(h, w) - 1) / 2
    if frame_annotation.viewpoint.intrinsics_format == "ndc_norm_image_bounds":
        K[0, 0] = f[0] * (w - 1) / 2
        K[1, 1] = f[1] * (h - 1) / 2
    elif frame_annotation.viewpoint.intrinsics_format == "ndc_isotropic":
        K[0, 0] = f[0] * s / 2
        K[1, 1] = f[1] * s / 2
    else:
        assert (
            False
        ), f"Invalid intrinsics_format: {frame_annotation.viewpoint.intrinsics_format}"
    K[0, 2] = -p[0] * s + (w - 1) / 2
    K[1, 2] = -p[1] * s + (h - 1) / 2
    R = np.array(frame_annotation.viewpoint.R).T  # note the transpose here
    T = np.array(frame_annotation.viewpoint.T)
    pose = np.concatenate([R, T[:, None]], 1)
    # Need to be converted into OpenGL format. Flip the direction of x, z axis
    pose = np.diag([-1, 1, -1]).astype(np.float32) @ pose
    return pose, K


def normalize_w2c_camera_pose_sequence(
    target_camera_poses,
    condition_camera_poses=None,
    output_c2w=False,
    translation_norm_mode="div_by_max",
):
    """
    Normalize camera pose sequence so that the first frame is identity rotation and zero translation,
    and the translation scale is normalized by the farthest point from the first frame (to one).

    :param target_camera_poses: W2C poses tensor in [N, 3, 4]
    :param condition_camera_poses: W2C poses tensor in [N, 3, 4]
    :return: Tuple(Tensor, Tensor), the normalized `target_camera_poses` and `condition_camera_poses`
    """
    # Normalize at w2c, all poses should be in w2c in UnifiedFrame
    num_target_views = target_camera_poses.size(0)
    if condition_camera_poses is not None:
        all_poses = torch.concat([target_camera_poses, condition_camera_poses], dim=0)
    else:
        all_poses = target_camera_poses
    # Convert W2C to C2W
    normalized_poses = transform_pose_sequence_to_relative_c2w(
        convert_w2c_between_c2w(all_poses)
    )
    # Here normalized_poses is C2W
    if not output_c2w:
        # Convert from C2W back to W2C if output_c2w is False.
        normalized_poses = convert_w2c_between_c2w(normalized_poses)
    t_norms = torch.linalg.norm(normalized_poses[:, :, 3], ord=2, dim=-1)
    # print("t_norms=", t_norms)
    largest_t_norm = torch.max(t_norms)
    # print("largest_t_norm=", largest_t_norm)
    # normalized_poses[:, :, 3] -= first_t.unsqueeze(0).repeat(normalized_poses.size(0), 1)
    if translation_norm_mode == "div_by_max_plus_one":
        # Always add a constant component to the translation norm
        largest_t_norm = largest_t_norm + 1.0
    elif translation_norm_mode == "div_by_max":
        if largest_t_norm <= 0.05:
            largest_t_norm = 0.05
    elif translation_norm_mode == "disabled":
        largest_t_norm = 1.0
    else:
        assert False, f"Invalid translation_norm_mode: {translation_norm_mode}."
    normalized_poses[:, :, 3] /= largest_t_norm
    target_camera_poses = normalized_poses[:num_target_views]
    if condition_camera_poses is not None:
        condition_camera_poses = normalized_poses[num_target_views:]
    else:
        condition_camera_poses = None
    # print("After First condition:", condition_camera_poses[0])
    # print("After First target:", target_camera_poses[0])
    return target_camera_poses, condition_camera_poses


def central_crop_pil_image(_image, crop_size, use_central_padding=False):
    if use_central_padding:
        # Determine the new size
        _w, _h = _image.size
        new_size = max(_w, _h)
        # Create a new image with white background
        new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
        # Calculate the position to paste the original image
        paste_position = ((new_size - _w) // 2, (new_size - _h) // 2)
        # Paste the original image onto the new image
        new_image.paste(_image, paste_position)
        _image = new_image
    # get the new size again if padded
    _w, _h = _image.size
    scale = crop_size / min(_h, _w)  # resize shortest side to crop_size
    _w_out, _h_out = int(scale * _w), int(scale * _h)
    _image = _image.resize(
        (_w_out, _h_out),
        resample=(
            PIL.Image.Resampling.LANCZOS if scale < 1 else PIL.Image.Resampling.BICUBIC
        ),
    )
    # center crop
    margin_w = (_image.size[0] - crop_size) // 2
    margin_h = (_image.size[1] - crop_size) // 2
    _image = _image.crop(
        (margin_w, margin_h, margin_w + crop_size, margin_h + crop_size)
    )
    return _image


def crop_and_resize(image: Image.Image, target_width: int, target_height: int) -> Image.Image:
    """
    Crops and resizes an image while preserving the aspect ratio.

    Args:
        image (Image.Image): Input PIL image to be cropped and resized.
        target_width (int): Target width of the output image.
        target_height (int): Target height of the output image.

    Returns:
        Image.Image: Cropped and resized image.
    """
    # Original dimensions
    original_width, original_height = image.size
    original_aspect = original_width / original_height
    target_aspect = target_width / target_height
    # Calculate crop box to maintain aspect ratio
    if original_aspect > target_aspect:
        # Crop horizontally
        new_width = int(original_height * target_aspect)
        new_height = original_height
        left = (original_width - new_width) / 2
        top = 0
        right = left + new_width
        bottom = original_height
    else:
        # Crop vertically
        new_width = original_width
        new_height = int(original_width / target_aspect)
        left = 0
        top = (original_height - new_height) / 2
        right = original_width
        bottom = top + new_height
    # Crop and resize
    cropped_image = image.crop((left, top, right, bottom))
    resized_image = cropped_image.resize((target_width, target_height), Image.LANCZOS)
    return resized_image


def calculate_fov_after_resize(
    fov_x: float,
    fov_y: float,
    original_width: int,
    original_height: int,
    target_width: int,
    target_height: int,
) -> (float, float):
    """
    Calculates the new field of view after cropping and resizing an image.

    Args:
        fov_x (float): Original field of view in the x-direction (horizontal).
        fov_y (float): Original field of view in the y-direction (vertical).
        original_width (int): Original width of the image.
        original_height (int): Original height of the image.
        target_width (int): Target width of the output image.
        target_height (int): Target height of the output image.

    Returns:
        (float, float): New field of view (fov_x, fov_y) after cropping and resizing.
    """
    original_aspect = original_width / original_height
    target_aspect = target_width / target_height
    if original_aspect > target_aspect:
        # Crop horizontally
        new_width = int(original_height * target_aspect)
        new_fov_x = fov_x * (new_width / original_width)
        new_fov_y = fov_y
    else:
        # Crop vertically
        new_height = int(original_width / target_aspect)
        new_fov_y = fov_y * (new_height / original_height)
        new_fov_x = fov_x
    return new_fov_x, new_fov_y
```
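A minimal sketch (not part of the commit) of how `get_ray_embeddings` is driven: one identity c2w pose and an arbitrary 60-degree FOV yield a per-pixel 6-channel Plücker embedding, `(o × d, d)` per ray.

```python
import torch
from core.data.combined_multi_view_dataset import get_ray_embeddings  # in-repo path

pose = torch.eye(4)[:3]    # (3, 4) identity c2w pose
poses = pose.unsqueeze(0)  # (1, 3, 4); the function iterates over dim 0
emb = get_ray_embeddings(poses, size_h=64, size_w=64, fov_xy_list=[(60.0, 60.0)])
print(emb.shape)  # torch.Size([1, 6, 64, 64])
```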
NVComposer/core/data/utils.py (new file, mode 100755, +184 −0)

```python
import copy
import random

from PIL import Image
import numpy as np


def create_relative(RT_list, K_1=4.7, dataset="syn"):
    if dataset == "realestate":
        scale_T = 1
        RT_list = [RT.reshape(3, 4) for RT in RT_list]
    elif dataset == "syn":
        scale_T = (470 / K_1) / 7.5
        """
        4.694746736956946052e+02 0.000000000000000000e+00 4.800000000000000000e+02
        0.000000000000000000e+00 4.694746736956946052e+02 2.700000000000000000e+02
        0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00
        """
    elif dataset == "zero123":
        scale_T = 0.5
    else:
        raise Exception("invalid dataset type")
    # convert x y z to x -y -z
    if dataset == "zero123":
        flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
        for i in range(len(RT_list)):
            RT_list[i] = np.dot(flip_matrix, RT_list[i])
    temp = []
    first_frame_RT = copy.deepcopy(RT_list[0])
    # first_frame_R_inv = np.linalg.inv(first_frame_RT[:,:3])
    first_frame_R_inv = first_frame_RT[:, :3].T
    first_frame_T = first_frame_RT[:, -1]
    for RT in RT_list:
        RT[:, :3] = np.dot(RT[:, :3], first_frame_R_inv)
        RT[:, -1] = RT[:, -1] - np.dot(RT[:, :3], first_frame_T)
        RT[:, -1] = RT[:, -1] * scale_T
        temp.append(RT)
    RT_list = temp
    if dataset == "realestate":
        RT_list = [RT.reshape(-1) for RT in RT_list]
    return RT_list


def sigma_matrix2(sig_x, sig_y, theta):
    """Calculate the rotated sigma matrix (two dimensional matrix).

    Args:
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.

    Returns:
        ndarray: Rotated sigma matrix.
    """
    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
    u_matrix = np.array(
        [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
    )
    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))


def mesh_grid(kernel_size):
    """Generate the mesh grid, centering at zero.

    Args:
        kernel_size (int):

    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
    xx, yy = np.meshgrid(ax, ax)
    xy = np.hstack(
        (
            xx.reshape((kernel_size * kernel_size, 1)),
            yy.reshape(kernel_size * kernel_size, 1),
        )
    ).reshape(kernel_size, kernel_size, 2)
    return xy, xx, yy


def pdf2(sigma_matrix, grid):
    """Calculate PDF of the bivariate Gaussian distribution.

    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        kernel (ndarrray): un-normalized kernel.
    """
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
    return kernel


def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.

    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool):

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    kernel = pdf2(sigma_matrix, grid)
    kernel = kernel / np.sum(kernel)
    return kernel


def rgba_to_rgb_with_bg(rgba_image, bg_color=(255, 255, 255)):
    """
    Convert a PIL RGBA Image to an RGB Image with a white background.

    Args:
        rgba_image (Image): A PIL Image object in RGBA mode.

    Returns:
        Image: A PIL Image object in RGB mode with white background.
    """
    # Ensure the image is in RGBA mode
    if rgba_image.mode != "RGBA":
        return rgba_image
        # raise ValueError("The image must be in RGBA mode")
    # Create a white background image
    white_bg_rgb = Image.new("RGB", rgba_image.size, bg_color)
    # Paste the RGBA image onto the white background using alpha channel as mask
    white_bg_rgb.paste(
        rgba_image, mask=rgba_image.split()[3]
    )  # 3 is the alpha channel index
    return white_bg_rgb


def random_order_preserving_selection(items, num):
    if num > len(items):
        print("WARNING: Item list is shorter than `num` given.")
        return items
    selected_indices = sorted(random.sample(range(len(items)), num))
    selected_items = [items[i] for i in selected_indices]
    return selected_items


def pad_pil_image_to_square(image, fill_color=(255, 255, 255)):
    """
    Pad an image to make it square with the given fill color.

    Args:
        image (PIL.Image): The original image.
        fill_color (tuple): The color to use for padding (default is white).

    Returns:
        PIL.Image: A new image that is padded to be square.
    """
    width, height = image.size
    # Determine the new size, which will be the maximum of width or height
    new_size = max(width, height)
    # Create a new image with the new size and fill color
    new_image = Image.new("RGB", (new_size, new_size), fill_color)
    # Calculate the position to paste the original image onto the new image
    # This calculation centers the original image in the new square canvas
    left = (new_size - width) // 2
    top = (new_size - height) // 2
    # Paste the original image into the new image
    new_image.paste(image, (left, top))
    return new_image
```
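A minimal sketch (not part of the commit): building a normalized 21×21 isotropic Gaussian kernel from the helpers above; the sigma is arbitrary.

```python
import numpy as np
from core.data.utils import bivariate_Gaussian  # in-repo path

kernel = bivariate_Gaussian(kernel_size=21, sig_x=2.0, sig_y=2.0, theta=0.0, isotropic=True)
print(kernel.shape, np.isclose(kernel.sum(), 1.0))  # (21, 21) True
```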
NVComposer/core/distributions.py (new file, mode 100755, +102 −0)

```python
import torch
import numpy as np


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self, noise=None):
        if noise is None:
            noise = torch.randn(self.mean.shape)
        x = self.mean + self.std * noise.to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"
    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
```
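A minimal sketch (not part of the commit) of the VAE-style usage pattern for `DiagonalGaussianDistribution`; the shapes are arbitrary.

```python
import torch
from core.distributions import DiagonalGaussianDistribution  # in-repo path

# An encoder would emit 2*C channels: C means stacked with C log-variances.
moments = torch.randn(2, 8, 16, 16)  # (B, 2*C, H, W) with C = 4
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()  # (2, 4, 16, 16), reparameterized
kl = posterior.kl()     # (2,), KL to a standard normal per batch item
print(z.shape, kl.shape)
```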
NVComposer/core/ema.py (new file, mode 100755, +84 −0)

```python
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")
        self.m_name2s_name = {}
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            "num_updates",
            (
                torch.tensor(0, dtype=torch.int)
                if use_num_upates
                else torch.tensor(-1, dtype=torch.int)
            ),
        )
        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace(".", "")
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)
        self.collected_params = []

    def forward(self, model):
        decay = self.decay
        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())
            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(
                        one_minus_decay * (shadow_params[sname] - m_param[key])
                    )
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
```
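A minimal sketch (not part of the commit) of the intended `LitEma` lifecycle: update the shadow weights after each optimizer step, then swap them in for evaluation and restore afterwards. The tiny model here is arbitrary.

```python
import torch
from torch import nn
from core.ema import LitEma  # in-repo path

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.9999)

# After each optimizer.step() in the training loop:
ema(model)  # forward() nudges each shadow buffer toward the live weights

# Evaluate with EMA weights, then put the live weights back:
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation here ...
ema.restore(model.parameters())
```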
NVComposer/core/losses/__init__.py (new file, mode 100755, +1 −0)

```python
from core.losses.contperceptual import LPIPSWithDiscriminator
```

(No newline at end of file.)
NVComposer/core/losses/contperceptual.py (new file, mode 100755, +173 −0)

```python
import torch
import torch.nn as nn
from einops import rearrange
from taming.modules.losses.vqperceptual import *


class LPIPSWithDiscriminator(nn.Module):
    def __init__(
        self,
        disc_start,
        logvar_init=0.0,
        kl_weight=1.0,
        pixelloss_weight=1.0,
        disc_num_layers=3,
        disc_in_channels=3,
        disc_factor=1.0,
        disc_weight=1.0,
        perceptual_weight=1.0,
        use_actnorm=False,
        disc_conditional=False,
        disc_loss="hinge",
        max_bs=None,
    ):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
        self.discriminator = NLayerDiscriminator(
            input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
        ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.max_bs = max_bs

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(
                nll_loss, self.last_layer[0], retain_graph=True
            )[0]
            g_grads = torch.autograd.grad(
                g_loss, self.last_layer[0], retain_graph=True
            )[0]
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        inputs,
        reconstructions,
        posteriors,
        optimizer_idx,
        global_step,
        last_layer=None,
        cond=None,
        split="train",
        weights=None,
    ):
        if inputs.dim() == 5:
            inputs = rearrange(inputs, "b c t h w -> (b t) c h w")
        if reconstructions.dim() == 5:
            reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w")
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            if self.max_bs is not None and self.max_bs < inputs.shape[0]:
                input_list = torch.split(inputs, self.max_bs, dim=0)
                reconstruction_list = torch.split(reconstructions, self.max_bs, dim=0)
                p_losses = [
                    self.perceptual_loss(
                        inputs.contiguous(), reconstructions.contiguous()
                    )
                    for inputs, reconstructions in zip(input_list, reconstruction_list)
                ]
                p_loss = torch.cat(p_losses, dim=0)
            else:
                p_loss = self.perceptual_loss(
                    inputs.contiguous(), reconstructions.contiguous()
                )
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(
                    torch.cat((reconstructions.contiguous(), cond), dim=1)
                )
            g_loss = -torch.mean(logits_fake)
            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                except RuntimeError:
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)
            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            loss = (
                weighted_nll_loss
                + self.kl_weight * kl_loss
                + d_weight * disc_factor * g_loss
            )
            log = {
                "{}/total_loss".format(split): loss.clone().detach().mean(),
                "{}/logvar".format(split): self.logvar.detach(),
                "{}/kl_loss".format(split): kl_loss.detach().mean(),
                "{}/nll_loss".format(split): nll_loss.detach().mean(),
                "{}/rec_loss".format(split): rec_loss.detach().mean(),
                "{}/d_weight".format(split): d_weight.detach(),
                "{}/disc_factor".format(split): torch.tensor(disc_factor),
                "{}/g_loss".format(split): g_loss.detach().mean(),
            }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(
                    torch.cat((inputs.contiguous().detach(), cond), dim=1)
                )
                logits_fake = self.discriminator(
                    torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
                )
            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
            log = {
                "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                "{}/logits_real".format(split): logits_real.detach().mean(),
                "{}/logits_fake".format(split): logits_fake.detach().mean(),
            }
            return d_loss, log
```
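The adaptive discriminator weight above balances the reconstruction and adversarial gradient norms at the decoder's last layer. A self-contained toy version of the same computation (not part of the commit; the layer and stand-in losses are arbitrary):

```python
import torch

last_layer = torch.nn.Parameter(torch.randn(4, 4))
x = torch.randn(4)
nll_loss = (last_layer @ x).pow(2).sum()  # stand-in reconstruction/NLL loss
g_loss = (last_layer @ x).sum()           # stand-in generator (GAN) loss

# d_weight = ||grad(nll)|| / (||grad(g)|| + eps), clamped, as in the class above.
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.clamp(
    torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4), 0.0, 1e4
).detach()
print(d_weight)  # scale for the GAN term before adding it to the total loss
```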
NVComposer/core/losses/vqperceptual.py (new file, mode 100755, +217 −0)

```python
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss


def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
    loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
    loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
    loss_real = (weights * loss_real).sum() / weights.sum()
    loss_fake = (weights * loss_fake).sum() / weights.sum()
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def adopt_weight(weight, global_step, threshold=0, value=0.0):
    if global_step < threshold:
        weight = value
    return weight


def measure_perplexity(predicted_indices, n_embed):
    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    avg_probs = encodings.mean(0)
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use


def l1(x, y):
    return torch.abs(x - y)


def l2(x, y):
    return torch.pow((x - y), 2)


class VQLPIPSWithDiscriminator(nn.Module):
    def __init__(
        self,
        disc_start,
        codebook_weight=1.0,
        pixelloss_weight=1.0,
        disc_num_layers=3,
        disc_in_channels=3,
        disc_factor=1.0,
        disc_weight=1.0,
        perceptual_weight=1.0,
        use_actnorm=False,
        disc_conditional=False,
        disc_ndf=64,
        disc_loss="hinge",
        n_classes=None,
        perceptual_loss="lpips",
        pixel_loss="l1",
    ):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight
        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2
        self.discriminator = NLayerDiscriminator(
            input_nc=disc_in_channels,
            n_layers=disc_num_layers,
            use_actnorm=use_actnorm,
            ndf=disc_ndf,
        ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(
                nll_loss, self.last_layer[0], retain_graph=True
            )[0]
            g_grads = torch.autograd.grad(
                g_loss, self.last_layer[0], retain_graph=True
            )[0]
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        codebook_loss,
        inputs,
        reconstructions,
        optimizer_idx,
        global_step,
        last_layer=None,
        cond=None,
        split="train",
        predicted_indices=None,
    ):
        # NOTE: `exists` is used below but is not defined or imported in this
        # module as committed.
        if not exists(codebook_loss):
            codebook_loss = torch.tensor([0.0]).to(inputs.device)
        # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(
                inputs.contiguous(), reconstructions.contiguous()
            )
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])
        nll_loss = rec_loss
        # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(
                    torch.cat((reconstructions.contiguous(), cond), dim=1)
                )
            g_loss = -torch.mean(logits_fake)
            try:
                d_weight = self.calculate_adaptive_weight(
                    nll_loss, g_loss, last_layer=last_layer
                )
            except RuntimeError:
                assert not self.training
                d_weight = torch.tensor(0.0)
            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            loss = (
                nll_loss
                + d_weight * disc_factor * g_loss
                + self.codebook_weight * codebook_loss.mean()
            )
            log = {
                "{}/total_loss".format(split): loss.clone().detach().mean(),
                "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                "{}/nll_loss".format(split): nll_loss.detach().mean(),
                "{}/rec_loss".format(split): rec_loss.detach().mean(),
                "{}/p_loss".format(split): p_loss.detach().mean(),
                "{}/d_weight".format(split): d_weight.detach(),
                "{}/disc_factor".format(split): torch.tensor(disc_factor),
                "{}/g_loss".format(split): g_loss.detach().mean(),
            }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(
                        predicted_indices, self.n_classes
                    )
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(
                    torch.cat((inputs.contiguous().detach(), cond), dim=1)
                )
                logits_fake = self.discriminator(
                    torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
                )
            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
            log = {
                "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                "{}/logits_real".format(split): logits_real.detach().mean(),
                "{}/logits_fake".format(split): logits_fake.detach().mean(),
            }
            return d_loss, log
```
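A minimal sketch (not part of the commit) of `measure_perplexity` on fake codebook indices; with near-uniform usage the perplexity approaches `n_embed`.

```python
import torch
from core.losses.vqperceptual import measure_perplexity  # in-repo path

indices = torch.randint(0, 16, (1024,))  # hypothetical quantizer output
perplexity, cluster_use = measure_perplexity(indices, n_embed=16)
print(float(perplexity), int(cluster_use))  # close to 16.0, 16
```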
NVComposer/core/models/autoencoder.py (new file, mode 100755, +395 −0)

```python
import os
import json
from contextlib import contextmanager
import torch
import numpy as np
from einops import rearrange
import torch.nn.functional as F
import torch.distributed as dist
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from taming.modules.vqvae.quantize import VectorQuantizer as VectorQuantizer
from core.modules.networks.ae_modules import Encoder, Decoder
from core.distributions import DiagonalGaussianDistribution
from utils.utils import instantiate_from_config
from utils.save_video import tensor2videogrids
from core.common import shape_to_str, gather_data


class AutoencoderKL(pl.LightningModule):
    def __init__(
        self,
        ddconfig,
        lossconfig,
        embed_dim,
        ckpt_path=None,
        ignore_keys=[],
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        test=False,
        logdir=None,
        input_dim=4,
        test_args=None,
    ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        self.input_dim = input_dim
        self.test = test
        self.test_args = test_args
        self.logdir = logdir
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        if self.test:
            self.init_test()

    def init_test(self):
        self.test = True
        save_dir = os.path.join(self.logdir, "test")
        if "ckpt" in self.test_args:
            ckpt_name = (
                os.path.basename(self.test_args.ckpt).split(".ckpt")[0]
                + f"_epoch{self._cur_epoch}"
            )
            self.root = os.path.join(save_dir, ckpt_name)
        else:
            self.root = save_dir
        if "test_subdir" in self.test_args:
            self.root = os.path.join(save_dir, self.test_args.test_subdir)
        self.root_zs = os.path.join(self.root, "zs")
        self.root_dec = os.path.join(self.root, "reconstructions")
        self.root_inputs = os.path.join(self.root, "inputs")
        os.makedirs(self.root, exist_ok=True)
        if self.test_args.save_z:
            os.makedirs(self.root_zs, exist_ok=True)
        if self.test_args.save_reconstruction:
            os.makedirs(self.root_dec, exist_ok=True)
        if self.test_args.save_input:
            os.makedirs(self.root_inputs, exist_ok=True)
        assert self.test_args is not None
        self.test_maximum = getattr(self.test_args, "test_maximum", None)  # 1500 # 12000/8
        self.count = 0
        self.eval_metrics = {}
        self.decodes = []
        self.save_decode_samples = 2048
        if getattr(self.test_args, "cal_metrics", False):
            # NOTE: EvalLpips is referenced but not imported in this file as committed.
            self.EvalLpips = EvalLpips()

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")
        try:
            self._cur_epoch = sd["epoch"]
            sd = sd["state_dict"]
        except:
            self._cur_epoch = "null"
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        # self.load_state_dict(sd, strict=True)
        print(f"Restored from {path}")

    def encode(self, x, **kwargs):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, **kwargs):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        # if len(x.shape) == 3:
        #     x = x[..., None]
        # if x.dim() == 4:
        #     x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if x.dim() == 5 and self.input_dim == 4:
            b, c, t, h, w = x.shape
            self.b = b
            self.t = t
            x = rearrange(x, "b c t h w -> (b t) c h w")
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log(
                "aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True
            )
            self.log_dict(
                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
            )
            return aeloss
        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log(
                "discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True
            )
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
            )
            return discloss

    def validation_step(self, batch, batch_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val",
        )
        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val",
        )
        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def test_step(self, batch, batch_idx):
        # save z, dec
        inputs = self.get_input(batch, self.image_key)
        # forward
        sample_posterior = True
        posterior = self.encode(inputs)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        # logs
        if self.test_args.save_z:
            torch.save(
                z,
                os.path.join(
                    self.root_zs,
                    f"zs_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.pt",
                ),
            )
        if self.test_args.save_reconstruction:
            tensor2videogrids(
                dec,
                self.root_dec,
                f"reconstructions_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.mp4",
                fps=10,
            )
        if self.test_args.save_input:
            tensor2videogrids(
                inputs,
                self.root_inputs,
                f"inputs_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.mp4",
                fps=10,
            )
        if "save_z" in self.test_args and self.test_args.save_z:
            dec_np = (dec.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) + 1) / 2 * 255
            dec_np = dec_np.astype(np.uint8)
            self.root_dec_np = os.path.join(self.root, "reconstructions_np")
            os.makedirs(self.root_dec_np, exist_ok=True)
            np.savez(
                os.path.join(
                    self.root_dec_np,
                    f"reconstructions_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(dec_np)}.npz",
                ),
                dec_np,
            )
        self.count += z.shape[0]
        # misc
        self.log("batch_idx", batch_idx, prog_bar=True)
        self.log_dict(self.eval_metrics, prog_bar=True, logger=True)
        torch.cuda.empty_cache()
        if self.test_maximum is not None:
            if self.count > self.test_maximum:
                import sys

                sys.exit()
            else:
                prog = self.count / self.test_maximum * 100
                print(f"Test progress: {prog:.2f}% [{self.count}/{self.test_maximum}]")

    @rank_zero_only
    def on_test_end(self):
        if self.test_args.cal_metrics:
            psnrs, ssims, ms_ssims, lpipses = [], [], [], []
            n_batches = 0
            n_samples = 0
            overall = {}
            for k, v in self.eval_metrics.items():
                psnrs.append(v["psnr"])
                ssims.append(v["ssim"])
                lpipses.append(v["lpips"])
                n_batches += 1
                n_samples += v["n_samples"]
            mean_psnr = sum(psnrs) / len(psnrs)
            mean_ssim = sum(ssims) / len(ssims)
            # overall['ms_ssim'] = min(ms_ssims)
            mean_lpips = sum(lpipses) / len(lpipses)
            overall = {
                "psnr": mean_psnr,
                "ssim": mean_ssim,
                "lpips": mean_lpips,
                "n_batches": n_batches,
                "n_samples": n_samples,
            }
            overall_t = torch.tensor([mean_psnr, mean_ssim, mean_lpips])
            # dump
            for k, v in overall.items():
                if isinstance(v, torch.Tensor):
                    overall[k] = float(v)
            with open(
                os.path.join(self.root, f"reconstruction_metrics.json"), "w"
            ) as f:
                json.dump(overall, f)
            f.close()

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters()),
            lr=lr,
            betas=(0.5, 0.9),
        )
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
        )
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
```
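One contract worth noting: `encode()` returns a `DiagonalGaussianDistribution` rather than a tensor, so callers choose between stochastic and deterministic latents. A small helper sketch (not part of the commit):

```python
import torch
from core.models.autoencoder import AutoencoderKL  # in-repo path

@torch.no_grad()
def reconstruct(vae: AutoencoderKL, x: torch.Tensor, stochastic: bool = True):
    posterior = vae.encode(x)  # x: (B, 3, H, W)
    z = posterior.sample() if stochastic else posterior.mode()
    return vae.decode(z), z
```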
NVComposer/core/models/diffusion.py (new file, mode 100755, +1679 −0)

(Diff collapsed on the original page; contents not shown.)
NVComposer/core/models/samplers/__init__.py (new file, mode 100755, empty)
NVComposer/core/models/samplers/ddim.py (new file, mode 100755, +546 −0)

(Diff collapsed on the original page; contents not shown.)
NVComposer/core/models/samplers/dpm_solver/__init__.py
0 → 100755
from .sampler import DPMSolverSampler
\ No newline at end of file
NVComposer/core/models/samplers/dpm_solver/dpm_solver.py
0 → 100755
This diff is collapsed.
NVComposer/core/models/samplers/dpm_solver/sampler.py
0 → 100755
"""SAMPLING ONLY."""
import
torch
from
.dpm_solver
import
NoiseScheduleVP
,
model_wrapper
,
DPM_Solver
MODEL_TYPES
=
{
"eps"
:
"noise"
,
"v"
:
"v"
}
class
DPMSolverSampler
(
object
):
def
__init__
(
self
,
model
,
**
kwargs
):
super
().
__init__
()
self
.
model
=
model
def
to_torch
(
x
):
return
x
.
clone
().
detach
().
to
(
torch
.
float32
).
to
(
model
.
device
)
self
.
register_buffer
(
"alphas_cumprod"
,
to_torch
(
model
.
alphas_cumprod
))
def
register_buffer
(
self
,
name
,
attr
):
if
type
(
attr
)
==
torch
.
Tensor
:
if
attr
.
device
!=
torch
.
device
(
"cuda"
):
attr
=
attr
.
to
(
torch
.
device
(
"cuda"
))
setattr
(
self
,
name
,
attr
)
@
torch
.
no_grad
()
def
sample
(
self
,
S
,
batch_size
,
shape
,
conditioning
=
None
,
x_T
=
None
,
unconditional_guidance_scale
=
1.0
,
unconditional_conditioning
=
None
,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**
kwargs
,
):
if
conditioning
is
not
None
:
if
isinstance
(
conditioning
,
dict
):
try
:
cbs
=
conditioning
[
list
(
conditioning
.
keys
())[
0
]].
shape
[
0
]
except
:
cbs
=
conditioning
[
list
(
conditioning
.
keys
())[
0
]][
0
].
shape
[
0
]
if
cbs
!=
batch_size
:
print
(
f
"Warning: Got
{
cbs
}
conditionings but batch-size is
{
batch_size
}
"
)
else
:
if
conditioning
.
shape
[
0
]
!=
batch_size
:
print
(
f
"Warning: Got
{
conditioning
.
shape
[
0
]
}
conditionings but batch-size is
{
batch_size
}
"
)
# sampling
T
,
C
,
H
,
W
=
shape
size
=
(
batch_size
,
T
,
C
,
H
,
W
)
print
(
f
"Data shape for DPM-Solver sampling is
{
size
}
, sampling steps
{
S
}
"
)
device
=
self
.
model
.
betas
.
device
if
x_T
is
None
:
img
=
torch
.
randn
(
size
,
device
=
device
)
else
:
img
=
x_T
ns
=
NoiseScheduleVP
(
"discrete"
,
alphas_cumprod
=
self
.
alphas_cumprod
)
model_fn
=
model_wrapper
(
lambda
x
,
t
,
c
:
self
.
model
.
apply_model
(
x
,
t
,
c
),
ns
,
model_type
=
MODEL_TYPES
[
self
.
model
.
parameterization
],
guidance_type
=
"classifier-free"
,
condition
=
conditioning
,
unconditional_condition
=
unconditional_conditioning
,
guidance_scale
=
unconditional_guidance_scale
,
)
dpm_solver
=
DPM_Solver
(
model_fn
,
ns
,
predict_x0
=
True
,
thresholding
=
False
)
x
=
dpm_solver
.
sample
(
img
,
steps
=
S
,
skip_type
=
"time_uniform"
,
method
=
"multistep"
,
order
=
2
,
lower_order_final
=
True
,
)
return
x
.
to
(
device
),
None
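This sampler (like the PLMS and UniPC samplers below) drives the model through a classifier-free-guidance wrapper (guidance_type="classifier-free" with guidance_scale above). A minimal, self-contained sketch of the combination rule such a wrapper applies; the guided_eps helper and the toy tensors are hypothetical, for illustration only:

import torch

def guided_eps(eps_uncond, eps_cond, guidance_scale):
    # classifier-free guidance: start from the unconditional prediction and
    # extrapolate toward the conditional one; scale 1.0 recovers eps_cond
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

e_uncond = torch.zeros(2, 4)  # hypothetical unconditional noise prediction
e_cond = torch.ones(2, 4)     # hypothetical conditional noise prediction
print(guided_eps(e_uncond, e_cond, 7.5))  # every entry 7.5: pushed past the conditional value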
NVComposer/core/models/samplers/plms.py
0 → 100755
"""SAMPLING ONLY."""
import
numpy
as
np
from
tqdm
import
tqdm
import
torch
from
core.models.utils_diffusion
import
(
make_ddim_sampling_parameters
,
make_ddim_time_steps
,
)
from
core.common
import
noise_like
class
PLMSSampler
(
object
):
def
__init__
(
self
,
model
,
schedule
=
"linear"
,
**
kwargs
):
super
().
__init__
()
self
.
model
=
model
self
.
ddpm_num_time_steps
=
model
.
num_time_steps
self
.
schedule
=
schedule
def
register_buffer
(
self
,
name
,
attr
):
if
type
(
attr
)
==
torch
.
Tensor
:
if
attr
.
device
!=
torch
.
device
(
"cuda"
):
attr
=
attr
.
to
(
torch
.
device
(
"cuda"
))
setattr
(
self
,
name
,
attr
)
def
make_schedule
(
self
,
ddim_num_steps
,
ddim_discretize
=
"uniform"
,
ddim_eta
=
0.0
,
verbose
=
True
):
if
ddim_eta
!=
0
:
raise
ValueError
(
"ddim_eta must be 0 for PLMS"
)
self
.
ddim_time_steps
=
make_ddim_time_steps
(
ddim_discr_method
=
ddim_discretize
,
num_ddim_time_steps
=
ddim_num_steps
,
num_ddpm_time_steps
=
self
.
ddpm_num_time_steps
,
verbose
=
verbose
,
)
alphas_cumprod
=
self
.
model
.
alphas_cumprod
assert
(
alphas_cumprod
.
shape
[
0
]
==
self
.
ddpm_num_time_steps
),
"alphas have to be defined for each timestep"
def
to_torch
(
x
):
return
x
.
clone
().
detach
().
to
(
torch
.
float32
).
to
(
self
.
model
.
device
)
self
.
register_buffer
(
"betas"
,
to_torch
(
self
.
model
.
betas
))
self
.
register_buffer
(
"alphas_cumprod"
,
to_torch
(
alphas_cumprod
))
self
.
register_buffer
(
"alphas_cumprod_prev"
,
to_torch
(
self
.
model
.
alphas_cumprod_prev
)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self
.
register_buffer
(
"sqrt_alphas_cumprod"
,
to_torch
(
np
.
sqrt
(
alphas_cumprod
.
cpu
()))
)
self
.
register_buffer
(
"sqrt_one_minus_alphas_cumprod"
,
to_torch
(
np
.
sqrt
(
1.0
-
alphas_cumprod
.
cpu
())),
)
self
.
register_buffer
(
"log_one_minus_alphas_cumprod"
,
to_torch
(
np
.
log
(
1.0
-
alphas_cumprod
.
cpu
()))
)
self
.
register_buffer
(
"sqrt_recip_alphas_cumprod"
,
to_torch
(
np
.
sqrt
(
1.0
/
alphas_cumprod
.
cpu
()))
)
self
.
register_buffer
(
"sqrt_recipm1_alphas_cumprod"
,
to_torch
(
np
.
sqrt
(
1.0
/
alphas_cumprod
.
cpu
()
-
1
)),
)
# ddim sampling parameters
ddim_sigmas
,
ddim_alphas
,
ddim_alphas_prev
=
make_ddim_sampling_parameters
(
alphacums
=
alphas_cumprod
.
cpu
(),
ddim_time_steps
=
self
.
ddim_time_steps
,
eta
=
ddim_eta
,
verbose
=
verbose
,
)
self
.
register_buffer
(
"ddim_sigmas"
,
ddim_sigmas
)
self
.
register_buffer
(
"ddim_alphas"
,
ddim_alphas
)
self
.
register_buffer
(
"ddim_alphas_prev"
,
ddim_alphas_prev
)
self
.
register_buffer
(
"ddim_sqrt_one_minus_alphas"
,
np
.
sqrt
(
1.0
-
ddim_alphas
))
sigmas_for_original_sampling_steps
=
ddim_eta
*
torch
.
sqrt
(
(
1
-
self
.
alphas_cumprod_prev
)
/
(
1
-
self
.
alphas_cumprod
)
*
(
1
-
self
.
alphas_cumprod
/
self
.
alphas_cumprod_prev
)
)
self
.
register_buffer
(
"ddim_sigmas_for_original_num_steps"
,
sigmas_for_original_sampling_steps
)
@
torch
.
no_grad
()
def
sample
(
self
,
S
,
batch_size
,
shape
,
conditioning
=
None
,
callback
=
None
,
normals_sequence
=
None
,
img_callback
=
None
,
quantize_x0
=
False
,
eta
=
0.0
,
mask
=
None
,
x0
=
None
,
temperature
=
1.0
,
noise_dropout
=
0.0
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
verbose
=
True
,
x_T
=
None
,
log_every_t
=
100
,
unconditional_guidance_scale
=
1.0
,
unconditional_conditioning
=
None
,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**
kwargs
,
):
if
conditioning
is
not
None
:
if
isinstance
(
conditioning
,
dict
):
cbs
=
conditioning
[
list
(
conditioning
.
keys
())[
0
]].
shape
[
0
]
if
cbs
!=
batch_size
:
print
(
f
"Warning: Got
{
cbs
}
conditionings but batch-size is
{
batch_size
}
"
)
else
:
if
conditioning
.
shape
[
0
]
!=
batch_size
:
print
(
f
"Warning: Got
{
conditioning
.
shape
[
0
]
}
conditionings but batch-size is
{
batch_size
}
"
)
self
.
make_schedule
(
ddim_num_steps
=
S
,
ddim_eta
=
eta
,
verbose
=
verbose
)
# sampling
C
,
H
,
W
=
shape
size
=
(
batch_size
,
C
,
H
,
W
)
print
(
f
"Data shape for PLMS sampling is
{
size
}
"
)
samples
,
intermediates
=
self
.
plms_sampling
(
conditioning
,
size
,
callback
=
callback
,
img_callback
=
img_callback
,
quantize_denoised
=
quantize_x0
,
mask
=
mask
,
x0
=
x0
,
ddim_use_original_steps
=
False
,
noise_dropout
=
noise_dropout
,
temperature
=
temperature
,
score_corrector
=
score_corrector
,
corrector_kwargs
=
corrector_kwargs
,
x_T
=
x_T
,
log_every_t
=
log_every_t
,
unconditional_guidance_scale
=
unconditional_guidance_scale
,
unconditional_conditioning
=
unconditional_conditioning
,
)
return
samples
,
intermediates
@
torch
.
no_grad
()
def
plms_sampling
(
self
,
cond
,
shape
,
x_T
=
None
,
ddim_use_original_steps
=
False
,
callback
=
None
,
time_steps
=
None
,
quantize_denoised
=
False
,
mask
=
None
,
x0
=
None
,
img_callback
=
None
,
log_every_t
=
100
,
temperature
=
1.0
,
noise_dropout
=
0.0
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
unconditional_guidance_scale
=
1.0
,
unconditional_conditioning
=
None
,
):
device
=
self
.
model
.
betas
.
device
b
=
shape
[
0
]
if
x_T
is
None
:
img
=
torch
.
randn
(
shape
,
device
=
device
)
else
:
img
=
x_T
if
time_steps
is
None
:
time_steps
=
(
self
.
ddpm_num_time_steps
if
ddim_use_original_steps
else
self
.
ddim_time_steps
)
elif
time_steps
is
not
None
and
not
ddim_use_original_steps
:
subset_end
=
(
int
(
min
(
time_steps
/
self
.
ddim_time_steps
.
shape
[
0
],
1
)
*
self
.
ddim_time_steps
.
shape
[
0
]
)
-
1
)
time_steps
=
self
.
ddim_time_steps
[:
subset_end
]
intermediates
=
{
"x_inter"
:
[
img
],
"pred_x0"
:
[
img
]}
time_range
=
(
list
(
reversed
(
range
(
0
,
time_steps
)))
if
ddim_use_original_steps
else
np
.
flip
(
time_steps
)
)
total_steps
=
time_steps
if
ddim_use_original_steps
else
time_steps
.
shape
[
0
]
print
(
f
"Running PLMS Sampling with
{
total_steps
}
time_steps"
)
iterator
=
tqdm
(
time_range
,
desc
=
"PLMS Sampler"
,
total
=
total_steps
)
old_eps
=
[]
for
i
,
step
in
enumerate
(
iterator
):
index
=
total_steps
-
i
-
1
ts
=
torch
.
full
((
b
,),
step
,
device
=
device
,
dtype
=
torch
.
long
)
ts_next
=
torch
.
full
(
(
b
,),
time_range
[
min
(
i
+
1
,
len
(
time_range
)
-
1
)],
device
=
device
,
dtype
=
torch
.
long
,
)
if
mask
is
not
None
:
assert
x0
is
not
None
img_orig
=
self
.
model
.
q_sample
(
x0
,
ts
)
img
=
img_orig
*
mask
+
(
1.0
-
mask
)
*
img
outs
=
self
.
p_sample_plms
(
img
,
cond
,
ts
,
index
=
index
,
use_original_steps
=
ddim_use_original_steps
,
quantize_denoised
=
quantize_denoised
,
temperature
=
temperature
,
noise_dropout
=
noise_dropout
,
score_corrector
=
score_corrector
,
corrector_kwargs
=
corrector_kwargs
,
unconditional_guidance_scale
=
unconditional_guidance_scale
,
unconditional_conditioning
=
unconditional_conditioning
,
old_eps
=
old_eps
,
t_next
=
ts_next
,
)
img
,
pred_x0
,
e_t
=
outs
old_eps
.
append
(
e_t
)
if
len
(
old_eps
)
>=
4
:
old_eps
.
pop
(
0
)
if
callback
:
callback
(
i
)
if
img_callback
:
img_callback
(
pred_x0
,
i
)
if
index
%
log_every_t
==
0
or
index
==
total_steps
-
1
:
intermediates
[
"x_inter"
].
append
(
img
)
intermediates
[
"pred_x0"
].
append
(
pred_x0
)
return
img
,
intermediates
@
torch
.
no_grad
()
def
p_sample_plms
(
self
,
x
,
c
,
t
,
index
,
repeat_noise
=
False
,
use_original_steps
=
False
,
quantize_denoised
=
False
,
temperature
=
1.0
,
noise_dropout
=
0.0
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
unconditional_guidance_scale
=
1.0
,
unconditional_conditioning
=
None
,
old_eps
=
None
,
t_next
=
None
,
):
b
,
*
_
,
device
=
*
x
.
shape
,
x
.
device
def
get_model_output
(
x
,
t
):
if
(
unconditional_conditioning
is
None
or
unconditional_guidance_scale
==
1.0
):
e_t
=
self
.
model
.
apply_model
(
x
,
t
,
c
)
else
:
x_in
=
torch
.
cat
([
x
]
*
2
)
t_in
=
torch
.
cat
([
t
]
*
2
)
c_in
=
torch
.
cat
([
unconditional_conditioning
,
c
])
e_t_uncond
,
e_t
=
self
.
model
.
apply_model
(
x_in
,
t_in
,
c_in
).
chunk
(
2
)
e_t
=
e_t_uncond
+
unconditional_guidance_scale
*
(
e_t
-
e_t_uncond
)
if
score_corrector
is
not
None
:
assert
self
.
model
.
parameterization
==
"eps"
e_t
=
score_corrector
.
modify_score
(
self
.
model
,
e_t
,
x
,
t
,
c
,
**
corrector_kwargs
)
return
e_t
alphas
=
self
.
model
.
alphas_cumprod
if
use_original_steps
else
self
.
ddim_alphas
alphas_prev
=
(
self
.
model
.
alphas_cumprod_prev
if
use_original_steps
else
self
.
ddim_alphas_prev
)
sqrt_one_minus_alphas
=
(
self
.
model
.
sqrt_one_minus_alphas_cumprod
if
use_original_steps
else
self
.
ddim_sqrt_one_minus_alphas
)
sigmas
=
(
self
.
model
.
ddim_sigmas_for_original_num_steps
if
use_original_steps
else
self
.
ddim_sigmas
)
def
get_x_prev_and_pred_x0
(
e_t
,
index
):
# select parameters corresponding to the currently considered timestep
a_t
=
torch
.
full
((
b
,
1
,
1
,
1
),
alphas
[
index
],
device
=
device
)
a_prev
=
torch
.
full
((
b
,
1
,
1
,
1
),
alphas_prev
[
index
],
device
=
device
)
sigma_t
=
torch
.
full
((
b
,
1
,
1
,
1
),
sigmas
[
index
],
device
=
device
)
sqrt_one_minus_at
=
torch
.
full
(
(
b
,
1
,
1
,
1
),
sqrt_one_minus_alphas
[
index
],
device
=
device
)
# current prediction for x_0
pred_x0
=
(
x
-
sqrt_one_minus_at
*
e_t
)
/
a_t
.
sqrt
()
if
quantize_denoised
:
pred_x0
,
_
,
*
_
=
self
.
model
.
first_stage_model
.
quantize
(
pred_x0
)
# direction pointing to x_t
dir_xt
=
(
1.0
-
a_prev
-
sigma_t
**
2
).
sqrt
()
*
e_t
noise
=
sigma_t
*
noise_like
(
x
.
shape
,
device
,
repeat_noise
)
*
temperature
if
noise_dropout
>
0.0
:
noise
=
torch
.
nn
.
functional
.
dropout
(
noise
,
p
=
noise_dropout
)
x_prev
=
a_prev
.
sqrt
()
*
pred_x0
+
dir_xt
+
noise
return
x_prev
,
pred_x0
e_t
=
get_model_output
(
x
,
t
)
if
len
(
old_eps
)
==
0
:
# Pseudo Improved Euler (2nd order)
x_prev
,
pred_x0
=
get_x_prev_and_pred_x0
(
e_t
,
index
)
e_t_next
=
get_model_output
(
x_prev
,
t_next
)
e_t_prime
=
(
e_t
+
e_t_next
)
/
2
elif
len
(
old_eps
)
==
1
:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime
=
(
3
*
e_t
-
old_eps
[
-
1
])
/
2
elif
len
(
old_eps
)
==
2
:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime
=
(
23
*
e_t
-
16
*
old_eps
[
-
1
]
+
5
*
old_eps
[
-
2
])
/
12
elif
len
(
old_eps
)
>=
3
:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime
=
(
55
*
e_t
-
59
*
old_eps
[
-
1
]
+
37
*
old_eps
[
-
2
]
-
9
*
old_eps
[
-
3
]
)
/
24
x_prev
,
pred_x0
=
get_x_prev_and_pred_x0
(
e_t_prime
,
index
)
return
x_prev
,
pred_x0
,
e_t
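The branch structure at the end of p_sample_plms is the classic Adams-Bashforth bootstrap: until four noise predictions are banked in old_eps, lower-order combinations are used. A quick standalone check, on a toy ODE y' = -y that is not from this repository, that the 4th-order coefficients (55, -59, 37, -9)/24 integrate markedly more accurately than plain Euler:

import math

# toy ODE y' = -y with exact solution y(t) = exp(-t); hypothetical test problem
def f(y):
    return -y

h = 0.05
t_end = 2.0
n = int(t_end / h)  # 40 steps

# plain Euler baseline
y_euler = 1.0
for _ in range(n):
    y_euler += h * f(y_euler)

# 4th-order Adams-Bashforth, bootstrapped here with exact values for y_0..y_3
ys = [math.exp(-i * h) for i in range(4)]
fs = [f(y) for y in ys]
for _ in range(n - 3):
    y_next = ys[-1] + h * (55 * fs[-1] - 59 * fs[-2] + 37 * fs[-3] - 9 * fs[-4]) / 24
    ys.append(y_next)
    fs.append(f(y_next))
    fs.pop(0)  # keep only the last four slopes, like the old_eps queue above

exact = math.exp(-t_end)
print(f"Euler error: {abs(y_euler - exact):.2e}")
print(f"AB4 error:   {abs(ys[-1] - exact):.2e}")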
NVComposer/core/models/samplers/uni_pc/__init__.py
0 → 100644
NVComposer/core/models/samplers/uni_pc/sampler.py
0 → 100644
"""SAMPLING ONLY."""
import
torch
from
.uni_pc
import
NoiseScheduleVP
,
model_wrapper
,
UniPC
class
UniPCSampler
(
object
):
def
__init__
(
self
,
model
,
**
kwargs
):
super
().
__init__
()
self
.
model
=
model
def
to_torch
(
x
):
return
x
.
clone
().
detach
().
to
(
torch
.
float32
).
to
(
model
.
device
)
self
.
register_buffer
(
"alphas_cumprod"
,
to_torch
(
model
.
alphas_cumprod
))
def
register_buffer
(
self
,
name
,
attr
):
if
type
(
attr
)
==
torch
.
Tensor
:
if
attr
.
device
!=
torch
.
device
(
"cuda"
):
attr
=
attr
.
to
(
torch
.
device
(
"cuda"
))
setattr
(
self
,
name
,
attr
)
@
torch
.
no_grad
()
def
sample
(
self
,
S
,
batch_size
,
shape
,
conditioning
=
None
,
x_T
=
None
,
unconditional_guidance_scale
=
1.0
,
unconditional_conditioning
=
None
,
):
# sampling
T
,
C
,
H
,
W
=
shape
size
=
(
batch_size
,
T
,
C
,
H
,
W
)
device
=
self
.
model
.
betas
.
device
if
x_T
is
None
:
img
=
torch
.
randn
(
size
,
device
=
device
)
else
:
img
=
x_T
ns
=
NoiseScheduleVP
(
"discrete"
,
alphas_cumprod
=
self
.
alphas_cumprod
)
model_fn
=
model_wrapper
(
lambda
x
,
t
,
c
:
self
.
model
.
apply_model
(
x
,
t
,
c
),
ns
,
model_type
=
"v"
,
guidance_type
=
"classifier-free"
,
condition
=
conditioning
,
unconditional_condition
=
unconditional_conditioning
,
guidance_scale
=
unconditional_guidance_scale
,
)
uni_pc
=
UniPC
(
model_fn
,
ns
,
predict_x0
=
True
,
thresholding
=
False
)
x
=
uni_pc
.
sample
(
img
,
steps
=
S
,
skip_type
=
"time_uniform"
,
method
=
"multistep"
,
order
=
2
,
lower_order_final
=
True
,
)
return
x
.
to
(
device
),
None
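Both this wrapper and the DPM-Solver one build a discrete NoiseScheduleVP directly from the model's alphas_cumprod. A small self-contained sketch of where such values typically come from in a DDPM-style setup; the linear beta range below is a common default, an assumption, not a value read from this repository:

import torch

# hypothetical linear beta schedule over 1000 steps (a common DDPM default)
num_steps = 1000
betas = torch.linspace(1e-4, 2e-2, num_steps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# VP forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise
sqrt_abar = alphas_cumprod.sqrt()
sqrt_one_minus_abar = (1.0 - alphas_cumprod).sqrt()
print(sqrt_abar[0].item(), sqrt_abar[-1].item())  # near 1.0 at t=0, near 0.0 at t=T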
NVComposer/core/models/samplers/uni_pc/uni_pc.py
0 → 100644
This diff is collapsed.