ModelZoo / FlashVideo_pytorch, commit 3b804999
Authored Feb 20, 2025 by chenzk
v1.0
Pipeline #2420 failed with stages in 0 seconds
Showing 20 changed files with 4753 additions and 0 deletions (+4753, -0)
flashvideo/sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-310.pyc (+0, -0)
flashvideo/sgm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc (+0, -0)
flashvideo/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-310.pyc (+0, -0)
flashvideo/sgm/modules/diffusionmodules/denoiser.py (+77, -0)
flashvideo/sgm/modules/diffusionmodules/denoiser_scaling.py (+77, -0)
flashvideo/sgm/modules/diffusionmodules/denoiser_weighting.py (+28, -0)
flashvideo/sgm/modules/diffusionmodules/discretizer.py (+152, -0)
flashvideo/sgm/modules/diffusionmodules/guiders.py (+130, -0)
flashvideo/sgm/modules/diffusionmodules/lora.py (+401, -0)
flashvideo/sgm/modules/diffusionmodules/loss.py (+171, -0)
flashvideo/sgm/modules/diffusionmodules/model.py (+779, -0)
flashvideo/sgm/modules/diffusionmodules/openaimodel.py (+1271, -0)
flashvideo/sgm/modules/diffusionmodules/sampling.py (+980, -0)
flashvideo/sgm/modules/diffusionmodules/sampling_utils.py (+174, -0)
flashvideo/sgm/modules/diffusionmodules/sigma_sampling.py (+94, -0)
flashvideo/sgm/modules/diffusionmodules/util.py (+370, -0)
flashvideo/sgm/modules/diffusionmodules/wrappers.py (+49, -0)
flashvideo/sgm/modules/distributions/__init__.py (+0, -0)
flashvideo/sgm/modules/distributions/__pycache__/__init__.cpython-310.pyc (+0, -0)
flashvideo/sgm/modules/distributions/__pycache__/distributions.cpython-310.pyc (+0, -0)
flashvideo/sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-310.pyc (new file, mode 100644): file added
flashvideo/sgm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc (new file, mode 100644): file added
flashvideo/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-310.pyc (new file, mode 100644): file added
flashvideo/sgm/modules/diffusionmodules/denoiser.py (new file, mode 100644)

from typing import Dict, Union

import torch
import torch.nn as nn

from ...util import append_dims, instantiate_from_config


class Denoiser(nn.Module):

    def __init__(self, weighting_config, scaling_config):
        super().__init__()

        self.weighting = instantiate_from_config(weighting_config)
        self.scaling = instantiate_from_config(scaling_config)

    def possibly_quantize_sigma(self, sigma):
        return sigma

    def possibly_quantize_c_noise(self, c_noise):
        return c_noise

    def w(self, sigma):
        return self.weighting(sigma)

    def forward(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        sigma = self.possibly_quantize_sigma(sigma)
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
        return network(input * c_in, c_noise, cond,
                       **additional_model_inputs) * c_out + input * c_skip


class DiscreteDenoiser(Denoiser):

    def __init__(
        self,
        weighting_config,
        scaling_config,
        num_idx,
        discretization_config,
        do_append_zero=False,
        quantize_c_noise=True,
        flip=True,
    ):
        super().__init__(weighting_config, scaling_config)
        sigmas = instantiate_from_config(discretization_config)(
            num_idx, do_append_zero=do_append_zero, flip=flip)
        self.sigmas = sigmas
        # self.register_buffer("sigmas", sigmas)
        self.quantize_c_noise = quantize_c_noise

    def sigma_to_idx(self, sigma):
        dists = sigma - self.sigmas.to(sigma.device)[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx):
        return self.sigmas.to(idx.device)[idx]

    def possibly_quantize_sigma(self, sigma):
        return self.idx_to_sigma(self.sigma_to_idx(sigma))

    def possibly_quantize_c_noise(self, c_noise):
        if self.quantize_c_noise:
            return self.sigma_to_idx(c_noise)
        else:
            return c_noise
flashvideo/sgm/modules/diffusionmodules/denoiser_scaling.py (new file, mode 100644)

from abc import ABC, abstractmethod
from typing import Any, Tuple

import torch


class DenoiserScaling(ABC):

    @abstractmethod
    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        pass


class EDMScaling:

    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2)**0.5
        c_in = 1 / (sigma**2 + self.sigma_data**2)**0.5
        c_noise = 0.25 * sigma.log()
        return c_skip, c_out, c_in, c_noise


class EpsScaling:

    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = torch.ones_like(sigma, device=sigma.device)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1.0)**0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise


class VScaling:

    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = 1.0 / (sigma**2 + 1.0)
        c_out = -sigma / (sigma**2 + 1.0)**0.5
        c_in = 1.0 / (sigma**2 + 1.0)**0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise


class VScalingWithEDMcNoise(DenoiserScaling):

    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = 1.0 / (sigma**2 + 1.0)
        c_out = -sigma / (sigma**2 + 1.0)**0.5
        c_in = 1.0 / (sigma**2 + 1.0)**0.5
        c_noise = 0.25 * sigma.log()
        return c_skip, c_out, c_in, c_noise


class VideoScaling:  # similar to VScaling

    def __call__(
        self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = alphas_cumprod_sqrt
        c_out = -((1 - alphas_cumprod_sqrt**2)**0.5)
        c_in = torch.ones_like(alphas_cumprod_sqrt,
                               device=alphas_cumprod_sqrt.device)
        c_noise = additional_model_inputs['idx'].clone()
        return c_skip, c_out, c_in, c_noise
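Reviewer note: these scaling classes are consumed by Denoiser.forward in denoiser.py above, which combines the four coefficients as network(input * c_in, c_noise, cond) * c_out + input * c_skip. The snippet below is only an illustrative, self-contained sketch of that combination using the VScaling formulas; dummy_network and the sample values are stand-ins, not part of this commit.

import torch

sigma = torch.tensor([0.5, 1.0, 2.0])

# VScaling coefficients, copied from the formulas above
c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0)**0.5
c_in = 1.0 / (sigma**2 + 1.0)**0.5
c_noise = sigma.clone()

x = torch.randn(3)  # toy noisy input, one value per sigma
dummy_network = lambda z, t: torch.zeros_like(z)  # placeholder for the real model

# Same combination rule as Denoiser.forward
denoised = dummy_network(x * c_in, c_noise) * c_out + x * c_skip
print(denoised)  # with a zero network this reduces to x * c_skip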
flashvideo/sgm/modules/diffusionmodules/denoiser_weighting.py (new file, mode 100644)

import torch


class UnitWeighting:

    def __call__(self, sigma):
        return torch.ones_like(sigma, device=sigma.device)


class EDMWeighting:

    def __init__(self, sigma_data=0.5):
        self.sigma_data = sigma_data

    def __call__(self, sigma):
        return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data)**2


class VWeighting(EDMWeighting):

    def __init__(self):
        super().__init__(sigma_data=1.0)


class EpsWeighting:

    def __call__(self, sigma):
        return sigma**-2.0
flashvideo/sgm/modules/diffusionmodules/discretizer.py (new file, mode 100644)

from abc import abstractmethod
from functools import partial

import numpy as np
import torch

from ...modules.diffusionmodules.util import make_beta_schedule
from ...util import append_zero


def generate_roughly_equally_spaced_steps(num_substeps: int,
                                          max_step: int) -> np.ndarray:
    return np.linspace(max_step - 1, 0, num_substeps,
                       endpoint=False).astype(int)[::-1]


class Discretization:

    def __call__(self,
                 n,
                 do_append_zero=True,
                 device='cpu',
                 flip=False,
                 return_idx=False):
        if return_idx:
            sigmas, idx = self.get_sigmas(n, device=device, return_idx=return_idx)
        else:
            sigmas = self.get_sigmas(n, device=device, return_idx=return_idx)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        if return_idx:
            return sigmas if not flip else torch.flip(sigmas, (0, )), idx
        else:
            return sigmas if not flip else torch.flip(sigmas, (0, ))

    @abstractmethod
    def get_sigmas(self, n, device):
        pass


class EDMDiscretization(Discretization):

    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def get_sigmas(self, n, device='cpu'):
        ramp = torch.linspace(0, 1, n, device=device)
        min_inv_rho = self.sigma_min**(1 / self.rho)
        max_inv_rho = self.sigma_max**(1 / self.rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**self.rho
        return sigmas


class LegacyDDPMDiscretization(Discretization):

    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
    ):
        super().__init__()
        self.num_timesteps = num_timesteps
        betas = make_beta_schedule('linear',
                                   num_timesteps,
                                   linear_start=linear_start,
                                   linear_end=linear_end)
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

    def get_sigmas(self, n, device='cpu'):
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(
                n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError

        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod)**0.5
        return torch.flip(sigmas, (0, ))  # sigma_t: 14.4 -> 0.029


class ZeroSNRDDPMDiscretization(Discretization):

    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
        shift_scale=1.0,  # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale)
        keep_start=False,
        post_shift=False,
    ):
        super().__init__()
        if keep_start and not post_shift:
            linear_start = linear_start / (shift_scale +
                                           (1 - shift_scale) * linear_start)
        self.num_timesteps = num_timesteps
        betas = make_beta_schedule('linear',
                                   num_timesteps,
                                   linear_start=linear_start,
                                   linear_end=linear_end)
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

        # SNR shift
        if not post_shift:
            self.alphas_cumprod = self.alphas_cumprod / (
                shift_scale + (1 - shift_scale) * self.alphas_cumprod)

        self.post_shift = post_shift
        self.shift_scale = shift_scale

    def get_sigmas(self, n, device='cpu', return_idx=False):
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(
                n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError

        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        alphas_cumprod = to_torch(alphas_cumprod)
        alphas_cumprod_sqrt = alphas_cumprod.sqrt()
        alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone()
        alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()

        alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
        alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (
            alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)

        if self.post_shift:
            alphas_cumprod_sqrt = (
                alphas_cumprod_sqrt**2 /
                (self.shift_scale +
                 (1 - self.shift_scale) * alphas_cumprod_sqrt**2))**0.5

        if return_idx:
            return torch.flip(alphas_cumprod_sqrt, (0, )), timesteps
        else:
            return torch.flip(alphas_cumprod_sqrt, (0, ))  # sqrt(alpha_t): 0 -> 0.99
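Reviewer note: EDMDiscretization.get_sigmas above is the Karras-style rho schedule, sigma_i = (sigma_max^(1/rho) + t_i * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho with t_i ramping from 0 to 1, so the sigmas decrease from sigma_max to sigma_min. A small standalone check (values chosen for illustration only, using the class defaults):

import torch

sigma_min, sigma_max, rho, n = 0.002, 80.0, 7.0, 5

# Same arithmetic as EDMDiscretization.get_sigmas
ramp = torch.linspace(0, 1, n)
min_inv_rho = sigma_min**(1 / rho)
max_inv_rho = sigma_max**(1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho

print(sigmas)  # monotonically decreasing from 80.0 down to 0.002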
flashvideo/sgm/modules/diffusionmodules/guiders.py (new file, mode 100644)

import logging
import math
from abc import ABC, abstractmethod
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import torch
from einops import rearrange, repeat

from ...util import append_dims, default, instantiate_from_config


class Guider(ABC):

    @abstractmethod
    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        pass

    def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict,
                       uc: Dict) -> Tuple[torch.Tensor, float, Dict]:
        pass


class VanillaCFG:
    """
    implements parallelized CFG
    """

    def __init__(self, scale, dyn_thresh_config=None):
        self.scale = scale
        scale_schedule = lambda scale, sigma: scale  # independent of step
        self.scale_schedule = partial(scale_schedule, scale)
        self.dyn_thresh = instantiate_from_config(
            default(
                dyn_thresh_config,
                {
                    'target':
                    'sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding'
                },
            ))

    def __call__(self, x, sigma, scale=None):
        x_u, x_c = x.chunk(2)
        scale_value = default(scale, self.scale_schedule(sigma))
        x_pred = self.dyn_thresh(x_u, x_c, scale_value)
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()

        for k in c:
            if k in ['vector', 'crossattn', 'concat']:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out


# class DynamicCFG(VanillaCFG):

#     def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
#         super().__init__(scale, dyn_thresh_config)
#         scale_schedule = (lambda scale, sigma, step_index: 1 + scale *
#                           (1 - math.cos(math.pi *
#                                         (step_index / num_steps)**exp)) / 2)
#         self.scale_schedule = partial(scale_schedule, scale)
#         self.dyn_thresh = instantiate_from_config(
#             default(
#                 dyn_thresh_config,
#                 {
#                     'target':
#                     'sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding'
#                 },
#             ))

#     def __call__(self, x, sigma, step_index, scale=None):
#         x_u, x_c = x.chunk(2)
#         scale_value = self.scale_schedule(sigma, step_index.item())
#         x_pred = self.dyn_thresh(x_u, x_c, scale_value)
#         return x_pred


class DynamicCFG(VanillaCFG):

    def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
        super().__init__(scale, dyn_thresh_config)
        self.scale = scale
        self.num_steps = num_steps
        self.exp = exp
        scale_schedule = (lambda scale, sigma, step_index: 1 + scale *
                          (1 - math.cos(math.pi *
                                        (step_index / num_steps)**exp)) / 2)
        #self.scale_schedule = partial(scale_schedule, scale)
        self.dyn_thresh = instantiate_from_config(
            default(
                dyn_thresh_config,
                {
                    'target':
                    'sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding'
                },
            ))

    def scale_schedule_dy(self, sigma, step_index):
        # print(self.scale)
        return 1 + self.scale * (1 - math.cos(math.pi * (
            step_index / self.num_steps)**self.exp)) / 2

    def __call__(self, x, sigma, step_index, scale=None):
        x_u, x_c = x.chunk(2)
        scale_value = self.scale_schedule_dy(sigma, step_index.item())
        x_pred = self.dyn_thresh(x_u, x_c, scale_value)
        return x_pred


class IdentityGuider:

    def __call__(self, x, sigma):
        return x

    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()

        for k in c:
            c_out[k] = c[k]

        return x, s, c_out
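Reviewer note: as the code above suggests, VanillaCFG.prepare_inputs stacks the unconditional and conditional batches so the network runs once on a doubled batch, and __call__ then chunks the output and hands (x_u, x_c, scale) to the dynamic-thresholding object. The default target, NoDynamicThresholding, is defined in sampling_utils.py (added in this commit but not shown here); the sketch below assumes it applies the standard classifier-free guidance combination x_u + scale * (x_c - x_u), which is an assumption, not something shown in this diff.

import torch

# Hypothetical stand-ins; the real cond/uncond dicts and model come from the sampler.
x = torch.randn(2, 4, 8, 8)  # model output for the doubled (uncond + cond) batch
scale = 6.0

x_u, x_c = x.chunk(2)  # unconditional / conditional halves
# Assumed behaviour of NoDynamicThresholding (standard CFG combination)
x_pred = x_u + scale * (x_c - x_u)
print(x_pred.shape)  # torch.Size([1, 4, 8, 8])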
flashvideo/sgm/modules/diffusionmodules/lora.py (new file, mode 100644)

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch.nn.functional as F
from torch import nn


class LoRALinearLayer(nn.Module):

    def __init__(self,
                 in_features,
                 out_features,
                 rank=4,
                 network_alpha=None,
                 device=None,
                 dtype=None):
        super().__init__()

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank
        self.out_features = out_features
        self.in_features = in_features

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class LoRAConv2dLayer(nn.Module):

    def __init__(self,
                 in_features,
                 out_features,
                 rank=4,
                 kernel_size=(1, 1),
                 stride=(1, 1),
                 padding=0,
                 network_alpha=None):
        super().__init__()

        self.down = nn.Conv2d(in_features,
                              rank,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              bias=False)
        # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
        # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
        self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)

        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class LoRACompatibleConv(nn.Conv2d):
    """
    A convolutional layer that can be used with LoRA.
    """

    def __init__(self,
                 *args,
                 lora_layer: Optional[LoRAConv2dLayer] = None,
                 scale: float = 1.0,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer
        self.scale = scale

    def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
        self.lora_layer = lora_layer

    def _fuse_lora(self, lora_scale=1.0):
        if self.lora_layer is None:
            return

        dtype, device = self.weight.data.dtype, self.weight.data.device

        w_orig = self.weight.data.float()
        w_up = self.lora_layer.up.weight.data.float()
        w_down = self.lora_layer.down.weight.data.float()

        if self.lora_layer.network_alpha is not None:
            w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank

        fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
        fusion = fusion.reshape(w_orig.shape)
        fused_weight = w_orig + (lora_scale * fusion)
        self.weight.data = fused_weight.to(device=device, dtype=dtype)

        # we can drop the lora layer now
        self.lora_layer = None

        # offload the up and down matrices to CPU to not blow the memory
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        self._lora_scale = lora_scale

    def _unfuse_lora(self):
        if not (hasattr(self, 'w_up') and hasattr(self, 'w_down')):
            return

        fused_weight = self.weight.data
        dtype, device = fused_weight.data.dtype, fused_weight.data.device

        self.w_up = self.w_up.to(device=device).float()
        self.w_down = self.w_down.to(device).float()

        fusion = torch.mm(self.w_up.flatten(start_dim=1),
                          self.w_down.flatten(start_dim=1))
        fusion = fusion.reshape(fused_weight.shape)
        unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
        self.weight.data = unfused_weight.to(device=device, dtype=dtype)

        self.w_up = None
        self.w_down = None

    def forward(self, hidden_states, scale: float = None):
        if scale is None:
            scale = self.scale
        if self.lora_layer is None:
            # make sure to the functional Conv2D function as otherwise torch.compile's graph will break
            # see: https://github.com/huggingface/diffusers/pull/4315
            return F.conv2d(hidden_states, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
        else:
            return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))


class LoRACompatibleLinear(nn.Linear):
    """
    A Linear layer that can be used with LoRA.
    """

    def __init__(self,
                 *args,
                 lora_layer: Optional[LoRALinearLayer] = None,
                 scale: float = 1.0,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer
        self.scale = scale

    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        self.lora_layer = lora_layer

    def _fuse_lora(self, lora_scale=1.0):
        if self.lora_layer is None:
            return

        dtype, device = self.weight.data.dtype, self.weight.data.device

        w_orig = self.weight.data.float()
        w_up = self.lora_layer.up.weight.data.float()
        w_down = self.lora_layer.down.weight.data.float()

        if self.lora_layer.network_alpha is not None:
            w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank

        fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
        self.weight.data = fused_weight.to(device=device, dtype=dtype)

        # we can drop the lora layer now
        self.lora_layer = None

        # offload the up and down matrices to CPU to not blow the memory
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        self._lora_scale = lora_scale

    def _unfuse_lora(self):
        if not (hasattr(self, 'w_up') and hasattr(self, 'w_down')):
            return

        fused_weight = self.weight.data
        dtype, device = fused_weight.dtype, fused_weight.device

        w_up = self.w_up.to(device=device).float()
        w_down = self.w_down.to(device).float()

        unfused_weight = fused_weight.float() - (
            self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
        self.weight.data = unfused_weight.to(device=device, dtype=dtype)

        self.w_up = None
        self.w_down = None

    def forward(self, hidden_states, scale: float = None):
        if scale is None:
            scale = self.scale
        if self.lora_layer is None:
            out = super().forward(hidden_states)
            return out
        else:
            out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
            return out


def _find_children(
    model,
    search_class: List[Type[nn.Module]] = [nn.Linear],
):
    """
    Find all modules of a certain class (or union of classes).

    Returns all matching modules, along with the parent of those moduless and the
    names they are referenced by.
    """
    # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
    for parent in model.modules():
        for name, module in parent.named_children():
            if any([isinstance(module, _class) for _class in search_class]):
                yield parent, name, module


def _find_modules_v2(
    model,
    ancestor_class: Optional[Set[str]] = None,
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [
        LoRACompatibleLinear,
        LoRACompatibleConv,
        LoRALinearLayer,
        LoRAConv2dLayer,
    ],
):
    """
    Find all modules of a certain class (or union of classes) that are direct or
    indirect descendants of other modules of a certain class (or union of classes).

    Returns all matching modules, along with the parent of those moduless and the
    names they are referenced by.
    """
    # Get the targets we should replace all linears under
    if ancestor_class is not None:
        ancestors = (module for module in model.modules()
                     if module.__class__.__name__ in ancestor_class)
    else:
        # this, incase you want to naively iterate over all modules.
        ancestors = [module for module in model.modules()]

    # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if any([isinstance(module, _class) for _class in search_class]):
                # Find the direct parent if this is a descendant, not a child, of target
                *path, name = fullname.split('.')
                parent = ancestor
                flag = False
                while path:
                    try:
                        parent = parent.get_submodule(path.pop(0))
                    except:
                        flag = True
                        break
                if flag:
                    continue
                # Skip this linear if it's a child of a LoraInjectedLinear
                if exclude_children_of and any([
                        isinstance(parent, _class)
                        for _class in exclude_children_of
                ]):
                    continue
                # Otherwise, yield it
                yield parent, name, module


_find_modules = _find_modules_v2


def inject_trainable_lora_extended(
    model: nn.Module,
    target_replace_module: Set[str] = None,
    rank: int = 4,
    scale: float = 1.0,
):
    for _module, name, _child_module in _find_modules(
            model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]):
        if _child_module.__class__ == nn.Linear:
            weight = _child_module.weight
            bias = _child_module.bias
            lora_layer = LoRALinearLayer(
                in_features=_child_module.in_features,
                out_features=_child_module.out_features,
                rank=rank,
            )
            _tmp = (LoRACompatibleLinear(
                _child_module.in_features,
                _child_module.out_features,
                lora_layer=lora_layer,
                scale=scale,
            ).to(weight.dtype).to(weight.device))
            _tmp.weight = weight
            if bias is not None:
                _tmp.bias = bias
        elif _child_module.__class__ == nn.Conv2d:
            weight = _child_module.weight
            bias = _child_module.bias
            lora_layer = LoRAConv2dLayer(
                in_features=_child_module.in_channels,
                out_features=_child_module.out_channels,
                rank=rank,
                kernel_size=_child_module.kernel_size,
                stride=_child_module.stride,
                padding=_child_module.padding,
            )
            _tmp = (LoRACompatibleConv(
                _child_module.in_channels,
                _child_module.out_channels,
                kernel_size=_child_module.kernel_size,
                stride=_child_module.stride,
                padding=_child_module.padding,
                lora_layer=lora_layer,
                scale=scale,
            ).to(weight.dtype).to(weight.device))
            _tmp.weight = weight
            if bias is not None:
                _tmp.bias = bias
        else:
            continue

        _module._modules[name] = _tmp
        # print('injecting lora layer to', _module, name)
    return


def update_lora_scale(
    model: nn.Module,
    target_module: Set[str] = None,
    scale: float = 1.0,
):
    for _module, name, _child_module in _find_modules(
            model, target_module,
            search_class=[LoRACompatibleLinear, LoRACompatibleConv]):
        _child_module.scale = scale
    return
flashvideo/sgm/modules/diffusionmodules/loss.py (new file, mode 100644)

from typing import List, Optional, Union

import torch
import torch.nn as nn
from omegaconf import ListConfig
from sat import mpu

from ...modules.autoencoding.lpips.loss.lpips import LPIPS
from ...util import append_dims, instantiate_from_config


class StandardDiffusionLoss(nn.Module):

    def __init__(
        self,
        sigma_sampler_config,
        type='l2',
        offset_noise_level=0.0,
        batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
    ):
        super().__init__()

        assert type in ['l2', 'l1', 'lpips']

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)

        self.type = type
        self.offset_noise_level = offset_noise_level

        if type == 'lpips':
            self.lpips = LPIPS().eval()

        if not batch2model_keys:
            batch2model_keys = []

        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]

        self.batch2model_keys = set(batch2model_keys)

    def __call__(self, network, denoiser, conditioner, input, batch):
        cond = conditioner(batch)
        additional_model_inputs = {
            key: batch[key]
            for key in self.batch2model_keys.intersection(batch)
        }

        sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            noise = (noise +
                     append_dims(torch.randn(input.shape[0]).to(input.device),
                                 input.ndim) * self.offset_noise_level)
            noise = noise.to(input.dtype)
        noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
        model_output = denoiser(network, noised_input, sigmas, cond,
                                **additional_model_inputs)
        w = append_dims(denoiser.w(sigmas), input.ndim)
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
        if self.type == 'l2':
            return torch.mean(
                (w * (model_output - target)**2).reshape(target.shape[0], -1), 1)
        elif self.type == 'l1':
            return torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
        elif self.type == 'lpips':
            loss = self.lpips(model_output, target).reshape(-1)
            return loss


class VideoDiffusionLoss(StandardDiffusionLoss):

    def __init__(self,
                 block_scale=None,
                 block_size=None,
                 min_snr_value=None,
                 fixed_frames=0,
                 **kwargs):
        self.fixed_frames = fixed_frames
        self.block_scale = block_scale
        self.block_size = block_size
        self.min_snr_value = min_snr_value
        super().__init__(**kwargs)

    def __call__(self, network, denoiser, conditioner, input, batch):
        cond = conditioner(batch)
        additional_model_inputs = {
            key: batch[key]
            for key in self.batch2model_keys.intersection(batch)
        }

        alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0],
                                                      return_idx=True)  # tensor([0.8585])

        if 'ref_noise_step' in self.share_cache:
            print(self.share_cache['ref_noise_step'])
            ref_noise_step = self.share_cache['ref_noise_step']
            ref_alphas_cumprod_sqrt = self.sigma_sampler.idx_to_sigma(
                torch.zeros(input.shape[0]).fill_(ref_noise_step).long())
            ref_alphas_cumprod_sqrt = ref_alphas_cumprod_sqrt.to(input.device)
            ref_x = self.share_cache['ref_x']
            ref_noise = torch.randn_like(ref_x)
            # *0.8505 + noise * 0.5128 sqrt(1-0.8505^2)**0.5
            ref_noised_input = ref_x * append_dims(ref_alphas_cumprod_sqrt, ref_x.ndim) \
                + ref_noise * append_dims((1 - ref_alphas_cumprod_sqrt**2)**0.5, ref_x.ndim)
            self.share_cache['ref_x'] = ref_noised_input

        alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
        idx = idx.to(input.device)

        noise = torch.randn_like(input)

        # broadcast noise
        mp_size = mpu.get_model_parallel_world_size()
        global_rank = torch.distributed.get_rank() // mp_size
        src = global_rank * mp_size
        torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group())
        torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group())
        torch.distributed.broadcast(alphas_cumprod_sqrt,
                                    src=src,
                                    group=mpu.get_model_parallel_group())

        additional_model_inputs['idx'] = idx

        if self.offset_noise_level > 0.0:
            noise = (noise +
                     append_dims(torch.randn(input.shape[0]).to(input.device),
                                 input.ndim) * self.offset_noise_level)

        noised_input = input.float() * append_dims(
            alphas_cumprod_sqrt, input.ndim) + noise * append_dims(
                (1 - alphas_cumprod_sqrt**2)**0.5, input.ndim)

        if 'concat_images' in batch.keys():
            cond['concat'] = batch['concat_images']

        # [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx'])
        model_output = denoiser(network, noised_input, alphas_cumprod_sqrt,
                                cond, **additional_model_inputs)
        w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim)  # v-pred

        if self.min_snr_value is not None:
            w = min(w, self.min_snr_value)
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
        if self.type == 'l2':
            # model_output.shape
            # torch.Size([1, 2, 16, 60, 88])
            return torch.mean(
                (w * (model_output - target)**2).reshape(target.shape[0], -1), 1)
        elif self.type == 'l1':
            return torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
        elif self.type == 'lpips':
            loss = self.lpips(model_output, target).reshape(-1)
            return loss
flashvideo/sgm/modules/diffusionmodules/model.py (new file, mode 100644)

# pytorch_diffusion + derived encoder decoder
import math
from typing import Any, Callable, Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from packaging import version

try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILABLE = True
except:
    XFORMERS_IS_AVAILABLE = False
    print("no module 'xformers'. Processing without...")

from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention


def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups,
                              num_channels=in_channels,
                              eps=1e-6,
                              affine=True)


class Upsample(nn.Module):

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode='nearest')
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)


class AttnBlock(nn.Module):

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def attention(self, h_: torch.Tensor) -> torch.Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q, k, v = map(
            lambda x: rearrange(x, 'b c h w -> b 1 (h w) c').contiguous(),
            (q, k, v))
        h_ = torch.nn.functional.scaled_dot_product_attention(
            q, k, v)  # scale is dim ** -0.5 per default
        # compute attention

        return rearrange(h_, 'b 1 (h w) c -> b c h w', h=h, w=w, c=c, b=b)

    def forward(self, x, **kwargs):
        h_ = x
        h_ = self.attention(h_)
        h_ = self.proj_out(h_)
        return x + h_


class MemoryEfficientAttnBlock(nn.Module):
    """
    Uses xformers efficient implementation,
    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    Note: this is a single-head self-attention operation
    """

    #

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.attention_op: Optional[Any] = None

    def attention(self, h_: torch.Tensor) -> torch.Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

        q, k, v = map(
            lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(
                0, 2, 1, 3).reshape(B * 1, t.shape[1], C).contiguous(),
            (q, k, v),
        )
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op)

        out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(
            0, 2, 1, 3).reshape(B, out.shape[1], C)
        return rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)

    def forward(self, x, **kwargs):
        h_ = x
        h_ = self.attention(h_)
        h_ = self.proj_out(h_)
        return x + h_


class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):

    def forward(self, x, context=None, mask=None, **unused_kwargs):
        b, c, h, w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        out = super().forward(x, context=context, mask=mask)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
        return x + out


def make_attn(in_channels, attn_type='vanilla', attn_kwargs=None):
    assert attn_type in [
        'vanilla',
        'vanilla-xformers',
        'memory-efficient-cross-attn',
        'linear',
        'none',
    ], f'attn_type {attn_type} unknown'
    if version.parse(torch.__version__) < version.parse('2.0.0') and attn_type != 'none':
        assert XFORMERS_IS_AVAILABLE, (
            f'We do not support vanilla attention in {torch.__version__} anymore, '
            f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
        )
        attn_type = 'vanilla-xformers'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == 'vanilla':
        assert attn_kwargs is None
        return AttnBlock(in_channels)
    elif attn_type == 'vanilla-xformers':
        print(f'building MemoryEfficientAttnBlock with {in_channels} in_channels...')
        return MemoryEfficientAttnBlock(in_channels)
    elif type == 'memory-efficient-cross-attn':
        attn_kwargs['query_dim'] = in_channels
        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
    elif attn_type == 'none':
        return nn.Identity(in_channels)
    else:
        return LinAttnBlock(in_channels)


class Model(nn.Module):

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
        use_linear_attn=False,
        attn_type='vanilla',
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = 'linear'
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch, self.temb_ch),
                torch.nn.Linear(self.temb_ch, self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    ))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    ))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t=None, context=None):
        # assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        return self.conv_out.weight


class Encoder(nn.Module):

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type='vanilla',
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = 'linear'
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    ))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type='vanilla',
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = 'linear'
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1, ) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print('Working with z of shape {} = {} dimensions.'.format(
            self.z_shape, np.prod(self.z_shape)))

        make_attn_cls = self._make_attn()
        make_resblock_cls = self._make_resblock()
        make_conv_cls = self._make_conv()
        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = make_resblock_cls(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
        self.mid.block_2 = make_resblock_cls(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    make_resblock_cls(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    ))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn_cls(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = make_conv_cls(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def _make_attn(self) -> Callable:
        return make_attn

    def _make_resblock(self) -> Callable:
        return ResnetBlock

    def _make_conv(self) -> Callable:
        return torch.nn.Conv2d

    def get_last_layer(self, **kwargs):
        return self.conv_out.weight

    def forward(self, z, **kwargs):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, **kwargs)
        h = self.mid.attn_1(h, **kwargs)
        h = self.mid.block_2(h, temb, **kwargs)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb, **kwargs)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, **kwargs)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h, **kwargs)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
flashvideo/sgm/modules/diffusionmodules/openaimodel.py
0 → 100644
View file @
3b804999
import
math
import
os
from
abc
import
abstractmethod
from
functools
import
partial
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
...modules.attention
import
SpatialTransformer
from
...modules.diffusionmodules.lora
import
(
inject_trainable_lora_extended
,
update_lora_scale
)
from
...modules.diffusionmodules.util
import
(
avg_pool_nd
,
checkpoint
,
conv_nd
,
linear
,
normalization
,
timestep_embedding
,
zero_module
)
from
...modules.video_attention
import
SpatialVideoTransformer
from
...util
import
default
,
exists
# dummy replace
def
convert_module_to_f16
(
x
):
pass
def
convert_module_to_f32
(
x
):
pass
## go
class
AttentionPool2d
(
nn
.
Module
):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def
__init__
(
self
,
spacial_dim
:
int
,
embed_dim
:
int
,
num_heads_channels
:
int
,
output_dim
:
int
=
None
,
):
super
().
__init__
()
self
.
positional_embedding
=
nn
.
Parameter
(
th
.
randn
(
embed_dim
,
spacial_dim
**
2
+
1
)
/
embed_dim
**
0.5
)
self
.
qkv_proj
=
conv_nd
(
1
,
embed_dim
,
3
*
embed_dim
,
1
)
self
.
c_proj
=
conv_nd
(
1
,
embed_dim
,
output_dim
or
embed_dim
,
1
)
self
.
num_heads
=
embed_dim
//
num_heads_channels
self
.
attention
=
QKVAttention
(
self
.
num_heads
)
def
forward
(
self
,
x
):
b
,
c
,
*
_spatial
=
x
.
shape
x
=
x
.
reshape
(
b
,
c
,
-
1
)
# NC(HW)
x
=
th
.
cat
([
x
.
mean
(
dim
=-
1
,
keepdim
=
True
),
x
],
dim
=-
1
)
# NC(HW+1)
x
=
x
+
self
.
positional_embedding
[
None
,
:,
:].
to
(
x
.
dtype
)
# NC(HW+1)
x
=
self
.
qkv_proj
(
x
)
x
=
self
.
attention
(
x
)
x
=
self
.
c_proj
(
x
)
return
x
[:,
:,
0
]
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(
        self,
        x: th.Tensor,
        emb: th.Tensor,
        context: Optional[th.Tensor] = None,
        image_only_indicator: Optional[th.Tensor] = None,
        time_context: Optional[int] = None,
        num_video_frames: Optional[int] = None,
    ):
        from ...modules.diffusionmodules.video_model import VideoResBlock

        for layer in self:
            module = layer
            if isinstance(module, TimestepBlock) and not isinstance(
                    module, VideoResBlock):
                x = layer(x, emb)
            elif isinstance(module, VideoResBlock):
                x = layer(x, emb, num_video_frames, image_only_indicator)
            elif isinstance(module, SpatialVideoTransformer):
                x = layer(
                    x,
                    context,
                    time_context,
                    num_video_frames,
                    image_only_indicator,
                )
            elif isinstance(module, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None,
                 padding=1, third_up=False):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.third_up = third_up
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3,
                                padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            t_factor = 1 if not self.third_up else 2
            x = F.interpolate(
                x,
                (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
                mode='nearest',
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.use_conv:
            x = self.conv(x)
        return x
class TransposedUpsample(nn.Module):
    'Learned 2x upsampling without padding'

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.up = nn.ConvTranspose2d(self.channels, self.out_channels,
                                     kernel_size=ks, stride=2)

    def forward(self, x):
        return self.up(x)
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None,
                 padding=1, third_down=False):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else
                                      (2, 2, 2))
        if use_conv:
            print(f'Building a Downsample layer with {dims} dims.')
            print(f' --> settings are:\n in-chn: {self.channels}, '
                  f'out-chn: {self.out_channels}, '
                  f'kernel-size: 3, stride: {stride}, padding: {padding}')
            if dims == 3:
                print(f' --> Downsampling third axis (time): {third_down}')
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
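

# --- Illustrative sketch (not part of the original file) ---------------------
# Upsample/Downsample double or halve the spatial resolution; with dims=3 the
# temporal axis is only rescaled when third_up/third_down is set. The tensor
# shapes below are assumptions chosen for illustration only.
def _example_up_down_sample():
    x = th.randn(1, 8, 4, 16, 16)  # (N, C, T, H, W)
    up = Upsample(8, use_conv=False, dims=3, third_up=False)
    down = Downsample(8, use_conv=False, dims=3, third_down=False)
    assert up(x).shape == (1, 8, 4, 32, 32)   # T unchanged, H/W doubled
    assert down(x).shape == (1, 8, 4, 8, 8)   # T unchanged, H/W halved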
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
        kernel_size=3,
        exchange_temb_dims=False,
        skip_t_emb=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.exchange_temb_dims = exchange_temb_dims

        if isinstance(kernel_size, Iterable):
            padding = [k // 2 for k in kernel_size]
        else:
            padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, kernel_size,
                    padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.skip_t_emb = skip_t_emb
        self.emb_out_channels = (2 * self.out_channels
                                 if use_scale_shift_norm else
                                 self.out_channels)
        if self.skip_t_emb:
            print(f'Skipping timestep embedding in {self.__class__.__name__}')
            assert not self.use_scale_shift_norm
            self.emb_layers = None
            self.exchange_temb_dims = False
        else:
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                linear(
                    emb_channels,
                    self.emb_out_channels,
                ),
            )

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(
                    dims,
                    self.out_channels,
                    self.out_channels,
                    kernel_size,
                    padding=padding,
                )),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels,
                                           kernel_size, padding=padding)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels,
                                           1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(self._forward, (x, emb), self.parameters(),
                          self.use_checkpoint)

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        if self.skip_t_emb:
            emb_out = th.zeros_like(h)
        else:
            emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            if self.exchange_temb_dims:
                emb_out = rearrange(emb_out, 'b t c ... -> b c t ...')
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
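

# --- Illustrative sketch (not part of the original file) ---------------------
# A ResBlock maps (x, emb) to an output of the same spatial size; out_channels
# only changes the channel dimension. The channel/embedding sizes below are
# assumptions for illustration, not values used elsewhere in the repo.
def _example_resblock_usage():
    blk = ResBlock(channels=64, emb_channels=256, dropout=0.0,
                   out_channels=128, use_scale_shift_norm=True)
    x = th.randn(2, 64, 32, 32)
    emb = th.randn(2, 256)  # timestep embedding, e.g. produced by time_embed
    out = blk(x, emb)
    assert out.shape == (2, 128, 32, 32)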
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x, **kwargs):
        # TODO add crossframe attention and use mixed checkpoint
        return checkpoint(self._forward, (x, ), self.parameters(), True)
        # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        # return pt_checkpoint(self._forward, x)  # pytorch

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3,
                              length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum('bct,bcs->bts', q * scale, k * scale)
        # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum('bts,bcs->bct', weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            'bct,bcs->bts',
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )
        # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum('bts,bcs->bct', weight,
                      v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
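

# --- Illustrative sketch (not part of the original file) ---------------------
# Both attention helpers consume a fused qkv tensor. For QKVAttention the
# layout is [N, 3 * H * C, T]; the output is [N, H * C, T]. The head count,
# head dimension, and token count below are assumptions for illustration.
def _example_qkv_attention():
    n_heads, head_dim, tokens = 4, 16, 77
    attn = QKVAttention(n_heads)
    qkv = th.randn(2, 3 * n_heads * head_dim, tokens)
    out = attn(qkv)
    assert out.shape == (2, n_heads * head_dim, tokens)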
class Timestep(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        return timestep_embedding(t, self.dim)


str_to_dtype = {'fp32': th.float32, 'fp16': th.float16, 'bf16': th.bfloat16}
class
UNetModel
(
nn
.
Module
):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def
__init__
(
self
,
in_channels
,
model_channels
,
out_channels
,
num_res_blocks
,
attention_resolutions
,
dropout
=
0
,
channel_mult
=
(
1
,
2
,
4
,
8
),
conv_resample
=
True
,
dims
=
2
,
num_classes
=
None
,
use_checkpoint
=
False
,
use_fp16
=
False
,
num_heads
=-
1
,
num_head_channels
=-
1
,
num_heads_upsample
=-
1
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
use_new_attention_order
=
False
,
use_spatial_transformer
=
False
,
# custom transformer support
transformer_depth
=
1
,
# custom transformer support
context_dim
=
None
,
# custom transformer support
n_embed
=
None
,
# custom support for prediction of discrete ids into codebook of first stage vq model
legacy
=
True
,
disable_self_attentions
=
None
,
num_attention_blocks
=
None
,
disable_middle_self_attn
=
False
,
use_linear_in_transformer
=
False
,
spatial_transformer_attn_type
=
'softmax'
,
adm_in_channels
=
None
,
use_fairscale_checkpoint
=
False
,
offload_to_cpu
=
False
,
transformer_depth_middle
=
None
,
dtype
=
'fp32'
,
lora_init
=
False
,
lora_rank
=
4
,
lora_scale
=
1.0
,
lora_weight_path
=
None
,
):
super
().
__init__
()
from
omegaconf.listconfig
import
ListConfig
self
.
dtype
=
str_to_dtype
[
dtype
]
if
use_spatial_transformer
:
assert
(
context_dim
is
not
None
),
'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if
context_dim
is
not
None
:
assert
(
use_spatial_transformer
),
'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
if
type
(
context_dim
)
==
ListConfig
:
context_dim
=
list
(
context_dim
)
if
num_heads_upsample
==
-
1
:
num_heads_upsample
=
num_heads
if
num_heads
==
-
1
:
assert
num_head_channels
!=
-
1
,
'Either num_heads or num_head_channels has to be set'
if
num_head_channels
==
-
1
:
assert
num_heads
!=
-
1
,
'Either num_heads or num_head_channels has to be set'
self
.
in_channels
=
in_channels
self
.
model_channels
=
model_channels
self
.
out_channels
=
out_channels
if
isinstance
(
transformer_depth
,
int
):
transformer_depth
=
len
(
channel_mult
)
*
[
transformer_depth
]
elif
isinstance
(
transformer_depth
,
ListConfig
):
transformer_depth
=
list
(
transformer_depth
)
transformer_depth_middle
=
default
(
transformer_depth_middle
,
transformer_depth
[
-
1
])
if
isinstance
(
num_res_blocks
,
int
):
self
.
num_res_blocks
=
len
(
channel_mult
)
*
[
num_res_blocks
]
else
:
if
len
(
num_res_blocks
)
!=
len
(
channel_mult
):
raise
ValueError
(
'provide num_res_blocks either as an int (globally constant) or '
'as a list/tuple (per-level) with the same length as channel_mult'
)
self
.
num_res_blocks
=
num_res_blocks
# self.num_res_blocks = num_res_blocks
if
disable_self_attentions
is
not
None
:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert
len
(
disable_self_attentions
)
==
len
(
channel_mult
)
if
num_attention_blocks
is
not
None
:
assert
len
(
num_attention_blocks
)
==
len
(
self
.
num_res_blocks
)
assert
all
(
map
(
lambda
i
:
self
.
num_res_blocks
[
i
]
>=
num_attention_blocks
[
i
],
range
(
len
(
num_attention_blocks
)),
))
print
(
f
'Constructor of UNetModel received num_attention_blocks=
{
num_attention_blocks
}
. '
f
'This option has LESS priority than attention_resolutions
{
attention_resolutions
}
, '
f
'i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, '
f
'attention will still not be set.'
)
# todo: convert to warning
self
.
attention_resolutions
=
attention_resolutions
self
.
dropout
=
dropout
self
.
channel_mult
=
channel_mult
self
.
conv_resample
=
conv_resample
self
.
num_classes
=
num_classes
self
.
use_checkpoint
=
use_checkpoint
if
use_fp16
:
print
(
'WARNING: use_fp16 was dropped and has no effect anymore.'
)
# self.dtype = th.float16 if use_fp16 else th.float32
self
.
num_heads
=
num_heads
self
.
num_head_channels
=
num_head_channels
self
.
num_heads_upsample
=
num_heads_upsample
self
.
predict_codebook_ids
=
n_embed
is
not
None
assert
use_fairscale_checkpoint
!=
use_checkpoint
or
not
(
use_checkpoint
or
use_fairscale_checkpoint
)
self
.
use_fairscale_checkpoint
=
False
checkpoint_wrapper_fn
=
(
partial
(
checkpoint_wrapper
,
offload_to_cpu
=
offload_to_cpu
)
if
self
.
use_fairscale_checkpoint
else
lambda
x
:
x
)
time_embed_dim
=
model_channels
*
4
self
.
time_embed
=
checkpoint_wrapper_fn
(
nn
.
Sequential
(
linear
(
model_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
))
if
self
.
num_classes
is
not
None
:
if
isinstance
(
self
.
num_classes
,
int
):
self
.
label_emb
=
nn
.
Embedding
(
num_classes
,
time_embed_dim
)
elif
self
.
num_classes
==
'continuous'
:
print
(
'setting up linear c_adm embedding layer'
)
self
.
label_emb
=
nn
.
Linear
(
1
,
time_embed_dim
)
elif
self
.
num_classes
==
'timestep'
:
self
.
label_emb
=
checkpoint_wrapper_fn
(
nn
.
Sequential
(
Timestep
(
model_channels
),
nn
.
Sequential
(
linear
(
model_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
),
))
elif
self
.
num_classes
==
'sequential'
:
assert
adm_in_channels
is
not
None
self
.
label_emb
=
nn
.
Sequential
(
nn
.
Sequential
(
linear
(
adm_in_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
))
else
:
raise
ValueError
()
self
.
input_blocks
=
nn
.
ModuleList
([
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
))
])
self
.
_feature_size
=
model_channels
input_block_chans
=
[
model_channels
]
ch
=
model_channels
ds
=
1
for
level
,
mult
in
enumerate
(
channel_mult
):
for
nr
in
range
(
self
.
num_res_blocks
[
level
]):
layers
=
[
checkpoint_wrapper_fn
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
mult
*
model_channels
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
))
]
ch
=
mult
*
model_channels
if
ds
in
attention_resolutions
:
if
num_head_channels
==
-
1
:
dim_head
=
ch
//
num_heads
else
:
num_heads
=
ch
//
num_head_channels
dim_head
=
num_head_channels
if
legacy
:
# num_heads = 1
dim_head
=
ch
//
num_heads
if
use_spatial_transformer
else
num_head_channels
if
exists
(
disable_self_attentions
):
disabled_sa
=
disable_self_attentions
[
level
]
else
:
disabled_sa
=
False
if
not
exists
(
num_attention_blocks
)
or
nr
<
num_attention_blocks
[
level
]:
layers
.
append
(
checkpoint_wrapper_fn
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
dim_head
,
use_new_attention_order
=
use_new_attention_order
,
))
if
not
use_spatial_transformer
else
checkpoint_wrapper_fn
(
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
[
level
],
context_dim
=
context_dim
,
disable_self_attn
=
disabled_sa
,
use_linear
=
use_linear_in_transformer
,
attn_type
=
spatial_transformer_attn_type
,
use_checkpoint
=
use_checkpoint
,
)))
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
_feature_size
+=
ch
input_block_chans
.
append
(
ch
)
if
level
!=
len
(
channel_mult
)
-
1
:
out_ch
=
ch
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
checkpoint_wrapper_fn
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
out_ch
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
down
=
True
,
)
)
if
resblock_updown
else
Downsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
))
)
ch
=
out_ch
input_block_chans
.
append
(
ch
)
ds
*=
2
self
.
_feature_size
+=
ch
if
num_head_channels
==
-
1
:
dim_head
=
ch
//
num_heads
else
:
num_heads
=
ch
//
num_head_channels
dim_head
=
num_head_channels
if
legacy
:
# num_heads = 1
dim_head
=
ch
//
num_heads
if
use_spatial_transformer
else
num_head_channels
self
.
middle_block
=
TimestepEmbedSequential
(
checkpoint_wrapper_fn
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
)),
checkpoint_wrapper_fn
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
dim_head
,
use_new_attention_order
=
use_new_attention_order
,
))
if
not
use_spatial_transformer
else
checkpoint_wrapper_fn
(
SpatialTransformer
(
# always uses a self-attn
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth_middle
,
context_dim
=
context_dim
,
disable_self_attn
=
disable_middle_self_attn
,
use_linear
=
use_linear_in_transformer
,
attn_type
=
spatial_transformer_attn_type
,
use_checkpoint
=
use_checkpoint
,
)),
checkpoint_wrapper_fn
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
)),
)
self
.
_feature_size
+=
ch
self
.
output_blocks
=
nn
.
ModuleList
([])
for
level
,
mult
in
list
(
enumerate
(
channel_mult
))[::
-
1
]:
for
i
in
range
(
self
.
num_res_blocks
[
level
]
+
1
):
ich
=
input_block_chans
.
pop
()
layers
=
[
checkpoint_wrapper_fn
(
ResBlock
(
ch
+
ich
,
time_embed_dim
,
dropout
,
out_channels
=
model_channels
*
mult
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
))
]
ch
=
model_channels
*
mult
if
ds
in
attention_resolutions
:
if
num_head_channels
==
-
1
:
dim_head
=
ch
//
num_heads
else
:
num_heads
=
ch
//
num_head_channels
dim_head
=
num_head_channels
if
legacy
:
# num_heads = 1
dim_head
=
ch
//
num_heads
if
use_spatial_transformer
else
num_head_channels
if
exists
(
disable_self_attentions
):
disabled_sa
=
disable_self_attentions
[
level
]
else
:
disabled_sa
=
False
if
not
exists
(
num_attention_blocks
)
or
i
<
num_attention_blocks
[
level
]:
layers
.
append
(
checkpoint_wrapper_fn
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads_upsample
,
num_head_channels
=
dim_head
,
use_new_attention_order
=
use_new_attention_order
,
))
if
not
use_spatial_transformer
else
checkpoint_wrapper_fn
(
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
[
level
],
context_dim
=
context_dim
,
disable_self_attn
=
disabled_sa
,
use_linear
=
use_linear_in_transformer
,
attn_type
=
spatial_transformer_attn_type
,
use_checkpoint
=
use_checkpoint
,
)))
if
level
and
i
==
self
.
num_res_blocks
[
level
]:
out_ch
=
ch
layers
.
append
(
checkpoint_wrapper_fn
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
out_ch
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
up
=
True
,
)
)
if
resblock_updown
else
Upsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
))
ds
//=
2
self
.
output_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
_feature_size
+=
ch
self
.
out
=
checkpoint_wrapper_fn
(
nn
.
Sequential
(
normalization
(
ch
),
nn
.
SiLU
(),
zero_module
(
conv_nd
(
dims
,
model_channels
,
out_channels
,
3
,
padding
=
1
)),
))
if
self
.
predict_codebook_ids
:
self
.
id_predictor
=
checkpoint_wrapper_fn
(
nn
.
Sequential
(
normalization
(
ch
),
conv_nd
(
dims
,
model_channels
,
n_embed
,
1
),
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
))
if
lora_init
:
self
.
_init_lora
(
lora_rank
,
lora_scale
,
lora_weight_path
)
def
_init_lora
(
self
,
rank
,
scale
,
ckpt_dir
=
None
):
inject_trainable_lora_extended
(
self
,
target_replace_module
=
None
,
rank
=
rank
,
scale
=
scale
)
if
ckpt_dir
is
not
None
:
with
open
(
os
.
path
.
join
(
ckpt_dir
,
'latest'
))
as
latest_file
:
latest
=
latest_file
.
read
().
strip
()
ckpt_path
=
os
.
path
.
join
(
ckpt_dir
,
latest
,
'mp_rank_00_model_states.pt'
)
print
(
f
'loading lora from
{
ckpt_path
}
'
)
sd
=
th
.
load
(
ckpt_path
)[
'module'
]
sd
=
{
key
[
len
(
'model.diffusion_model'
):]:
sd
[
key
]
for
key
in
sd
if
key
.
startswith
(
'model.diffusion_model'
)
}
self
.
load_state_dict
(
sd
,
strict
=
False
)
def
_update_scale
(
self
,
scale
):
update_lora_scale
(
self
,
scale
)
def
convert_to_fp16
(
self
):
"""
Convert the torso of the model to float16.
"""
self
.
input_blocks
.
apply
(
convert_module_to_f16
)
self
.
middle_block
.
apply
(
convert_module_to_f16
)
self
.
output_blocks
.
apply
(
convert_module_to_f16
)
def
convert_to_fp32
(
self
):
"""
Convert the torso of the model to float32.
"""
self
.
input_blocks
.
apply
(
convert_module_to_f32
)
self
.
middle_block
.
apply
(
convert_module_to_f32
)
self
.
output_blocks
.
apply
(
convert_module_to_f32
)
def
forward
(
self
,
x
,
timesteps
=
None
,
context
=
None
,
y
=
None
,
**
kwargs
):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert
(
y
is
not
None
)
==
(
self
.
num_classes
is
not
None
),
'must specify y if and only if the model is class-conditional'
hs
=
[]
t_emb
=
timestep_embedding
(
timesteps
,
self
.
model_channels
,
repeat_only
=
False
,
dtype
=
self
.
dtype
)
emb
=
self
.
time_embed
(
t_emb
)
if
self
.
num_classes
is
not
None
:
assert
y
.
shape
[
0
]
==
x
.
shape
[
0
]
emb
=
emb
+
self
.
label_emb
(
y
)
# h = x.type(self.dtype)
h
=
x
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
,
context
)
hs
.
append
(
h
)
h
=
self
.
middle_block
(
h
,
emb
,
context
)
for
module
in
self
.
output_blocks
:
h
=
th
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
module
(
h
,
emb
,
context
)
h
=
h
.
type
(
x
.
dtype
)
if
self
.
predict_codebook_ids
:
assert
False
,
'not supported anymore. what the f*** are you doing?'
else
:
return
self
.
out
(
h
)
class
NoTimeUNetModel
(
UNetModel
):
def
forward
(
self
,
x
,
timesteps
=
None
,
context
=
None
,
y
=
None
,
**
kwargs
):
timesteps
=
th
.
zeros_like
(
timesteps
)
return
super
().
forward
(
x
,
timesteps
,
context
,
y
,
**
kwargs
)
class
EncoderUNetModel
(
nn
.
Module
):
"""
The half UNet model with attention and timestep embedding.
For usage, see UNet.
"""
def
__init__
(
self
,
image_size
,
in_channels
,
model_channels
,
out_channels
,
num_res_blocks
,
attention_resolutions
,
dropout
=
0
,
channel_mult
=
(
1
,
2
,
4
,
8
),
conv_resample
=
True
,
dims
=
2
,
use_checkpoint
=
False
,
use_fp16
=
False
,
num_heads
=
1
,
num_head_channels
=-
1
,
num_heads_upsample
=-
1
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
use_new_attention_order
=
False
,
pool
=
'adaptive'
,
*
args
,
**
kwargs
,
):
super
().
__init__
()
if
num_heads_upsample
==
-
1
:
num_heads_upsample
=
num_heads
self
.
in_channels
=
in_channels
self
.
model_channels
=
model_channels
self
.
out_channels
=
out_channels
self
.
num_res_blocks
=
num_res_blocks
self
.
attention_resolutions
=
attention_resolutions
self
.
dropout
=
dropout
self
.
channel_mult
=
channel_mult
self
.
conv_resample
=
conv_resample
self
.
use_checkpoint
=
use_checkpoint
self
.
dtype
=
th
.
float16
if
use_fp16
else
th
.
float32
self
.
num_heads
=
num_heads
self
.
num_head_channels
=
num_head_channels
self
.
num_heads_upsample
=
num_heads_upsample
time_embed_dim
=
model_channels
*
4
self
.
time_embed
=
nn
.
Sequential
(
linear
(
model_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
)
self
.
input_blocks
=
nn
.
ModuleList
([
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
))
])
self
.
_feature_size
=
model_channels
input_block_chans
=
[
model_channels
]
ch
=
model_channels
ds
=
1
for
level
,
mult
in
enumerate
(
channel_mult
):
for
_
in
range
(
num_res_blocks
):
layers
=
[
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
mult
*
model_channels
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
)
]
ch
=
mult
*
model_channels
if
ds
in
attention_resolutions
:
layers
.
append
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
num_head_channels
,
use_new_attention_order
=
use_new_attention_order
,
))
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
_feature_size
+=
ch
input_block_chans
.
append
(
ch
)
if
level
!=
len
(
channel_mult
)
-
1
:
out_ch
=
ch
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
out_ch
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
down
=
True
,
)
if
resblock_updown
else
Downsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
))
)
ch
=
out_ch
input_block_chans
.
append
(
ch
)
ds
*=
2
self
.
_feature_size
+=
ch
self
.
middle_block
=
TimestepEmbedSequential
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
),
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
num_head_channels
,
use_new_attention_order
=
use_new_attention_order
,
),
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
),
)
self
.
_feature_size
+=
ch
self
.
pool
=
pool
if
pool
==
'adaptive'
:
self
.
out
=
nn
.
Sequential
(
normalization
(
ch
),
nn
.
SiLU
(),
nn
.
AdaptiveAvgPool2d
((
1
,
1
)),
zero_module
(
conv_nd
(
dims
,
ch
,
out_channels
,
1
)),
nn
.
Flatten
(),
)
elif
pool
==
'attention'
:
assert
num_head_channels
!=
-
1
self
.
out
=
nn
.
Sequential
(
normalization
(
ch
),
nn
.
SiLU
(),
AttentionPool2d
((
image_size
//
ds
),
ch
,
num_head_channels
,
out_channels
),
)
elif
pool
==
'spatial'
:
self
.
out
=
nn
.
Sequential
(
nn
.
Linear
(
self
.
_feature_size
,
2048
),
nn
.
ReLU
(),
nn
.
Linear
(
2048
,
self
.
out_channels
),
)
elif
pool
==
'spatial_v2'
:
self
.
out
=
nn
.
Sequential
(
nn
.
Linear
(
self
.
_feature_size
,
2048
),
normalization
(
2048
),
nn
.
SiLU
(),
nn
.
Linear
(
2048
,
self
.
out_channels
),
)
else
:
raise
NotImplementedError
(
f
'Unexpected
{
pool
}
pooling'
)
def
convert_to_fp16
(
self
):
"""
Convert the torso of the model to float16.
"""
self
.
input_blocks
.
apply
(
convert_module_to_f16
)
self
.
middle_block
.
apply
(
convert_module_to_f16
)
def
convert_to_fp32
(
self
):
"""
Convert the torso of the model to float32.
"""
self
.
input_blocks
.
apply
(
convert_module_to_f32
)
self
.
middle_block
.
apply
(
convert_module_to_f32
)
def
forward
(
self
,
x
,
timesteps
):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
results
=
[]
# h = x.type(self.dtype)
h
=
x
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
)
if
self
.
pool
.
startswith
(
'spatial'
):
results
.
append
(
h
.
type
(
x
.
dtype
).
mean
(
dim
=
(
2
,
3
)))
h
=
self
.
middle_block
(
h
,
emb
)
if
self
.
pool
.
startswith
(
'spatial'
):
results
.
append
(
h
.
type
(
x
.
dtype
).
mean
(
dim
=
(
2
,
3
)))
h
=
th
.
cat
(
results
,
axis
=-
1
)
return
self
.
out
(
h
)
else
:
h
=
h
.
type
(
x
.
dtype
)
return
self
.
out
(
h
)
if
__name__
==
'__main__'
:
class
Dummy
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
=
3
,
model_channels
=
64
):
super
().
__init__
()
self
.
input_blocks
=
nn
.
ModuleList
([
TimestepEmbedSequential
(
conv_nd
(
2
,
in_channels
,
model_channels
,
3
,
padding
=
1
))
])
model
=
UNetModel
(
use_checkpoint
=
True
,
image_size
=
64
,
in_channels
=
4
,
out_channels
=
4
,
model_channels
=
128
,
attention_resolutions
=
[
4
,
2
],
num_res_blocks
=
2
,
channel_mult
=
[
1
,
2
,
4
],
num_head_channels
=
64
,
use_spatial_transformer
=
False
,
use_linear_in_transformer
=
True
,
transformer_depth
=
1
,
legacy
=
False
,
).
cuda
()
x
=
th
.
randn
(
11
,
4
,
64
,
64
).
cuda
()
t
=
th
.
randint
(
low
=
0
,
high
=
10
,
size
=
(
11
,
),
device
=
'cuda'
)
o
=
model
(
x
,
t
)
print
(
'done.'
)
flashvideo/sgm/modules/diffusionmodules/sampling.py
"""
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
"""
from
typing
import
Dict
,
Union
import
torch
from
omegaconf
import
ListConfig
,
OmegaConf
from
tqdm
import
tqdm
from
...modules.diffusionmodules.sampling_utils
import
(
get_ancestral_step
,
linear_multistep_coeff
,
to_d
,
to_neg_log_sigma
,
to_sigma
)
from
...util
import
SeededNoise
,
append_dims
,
default
,
instantiate_from_config
from
.guiders
import
DynamicCFG
DEFAULT_GUIDER
=
{
'target'
:
'sgm.modules.diffusionmodules.guiders.IdentityGuider'
}
class BaseDiffusionSampler:

    def __init__(
        self,
        discretization_config: Union[Dict, ListConfig, OmegaConf],
        num_steps: Union[int, None] = None,
        guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
        verbose: bool = False,
        device: str = 'cuda',
    ):
        self.num_steps = num_steps
        self.discretization = instantiate_from_config(discretization_config)
        self.guider = instantiate_from_config(
            default(
                guider_config,
                DEFAULT_GUIDER,
            ))
        self.verbose = verbose
        self.device = device

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        sigmas = self.discretization(
            self.num_steps if num_steps is None else num_steps,
            device=self.device)
        uc = default(uc, cond)

        x *= torch.sqrt(1.0 + sigmas[0]**2.0)
        num_sigmas = len(sigmas)

        s_in = x.new_ones([x.shape[0]]).float()

        return x, s_in, sigmas, num_sigmas, cond, uc

    def denoise(self, x, denoiser, sigma, cond, uc):
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
        denoised = self.guider(denoised, sigma)
        return denoised

    def get_sigma_gen(self, num_sigmas):
        sigma_generator = range(num_sigmas - 1)
        if self.verbose:
            print('#' * 30, ' Sampling setting ', '#' * 30)
            print(f'Sampler: {self.__class__.__name__}')
            print(f'Discretization: {self.discretization.__class__.__name__}')
            print(f'Guider: {self.guider.__class__.__name__}')
            sigma_generator = tqdm(
                sigma_generator,
                total=num_sigmas,
                desc=f'Sampling with {self.__class__.__name__} for {num_sigmas} steps',
            )
        return sigma_generator
class SingleStepDiffusionSampler(BaseDiffusionSampler):

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args,
                     **kwargs):
        raise NotImplementedError

    def euler_step(self, x, d, dt):
        return x + dt * d
class EDMSampler(SingleStepDiffusionSampler):

    def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float('inf'),
                 s_noise=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None,
                     gamma=0.0):
        sigma_hat = sigma * (gamma + 1.0)
        if gamma > 0:
            eps = torch.randn_like(x) * self.s_noise
            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim)**0.5

        denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)

        euler_step = self.euler_step(x, d, dt)
        x = self.possible_correction_step(euler_step, x, d, dt, next_sigma,
                                          denoiser, cond, uc)
        return x

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps)

        for i in self.get_sigma_gen(num_sigmas):
            gamma = (min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                     if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0)
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc,
                gamma,
            )

        return x
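

# --- Illustrative sketch (not part of the original file) ---------------------
# One EDM Euler step written out standalone, assuming to_d's usual definition
# d = (x - denoised) / sigma. The toy denoiser and the sigma values are
# stand-ins for the guided model call performed by self.denoise.
def _example_edm_euler_step():
    x = torch.randn(1, 4, 8, 8) * 14.6       # start from sigma_max-scaled noise
    sigma, next_sigma = 14.6, 7.3

    def toy_denoiser(x_t, sigma_t):
        return x_t * 0.0                      # pretend the clean estimate is 0

    denoised = toy_denoiser(x, sigma)
    d = (x - denoised) / sigma                # to_d(x, sigma, denoised)
    x_next = x + (next_sigma - sigma) * d     # euler_step(x, d, dt)
    return x_next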
class
DDIMSampler
(
SingleStepDiffusionSampler
):
def
__init__
(
self
,
s_noise
=
0.1
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
s_noise
=
s_noise
def
sampler_step
(
self
,
sigma
,
next_sigma
,
denoiser
,
x
,
cond
,
uc
=
None
,
s_noise
=
0.0
):
denoised
=
self
.
denoise
(
x
,
denoiser
,
sigma
,
cond
,
uc
)
d
=
to_d
(
x
,
sigma
,
denoised
)
dt
=
append_dims
(
next_sigma
*
(
1
-
s_noise
**
2
)
**
0.5
-
sigma
,
x
.
ndim
)
euler_step
=
x
+
dt
*
d
+
s_noise
*
append_dims
(
next_sigma
,
x
.
ndim
)
*
torch
.
randn_like
(
x
)
x
=
self
.
possible_correction_step
(
euler_step
,
x
,
d
,
dt
,
next_sigma
,
denoiser
,
cond
,
uc
)
return
x
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
):
x
,
s_in
,
sigmas
,
num_sigmas
,
cond
,
uc
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
x
=
self
.
sampler_step
(
s_in
*
sigmas
[
i
],
s_in
*
sigmas
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
,
self
.
s_noise
,
)
return
x
class
AncestralSampler
(
SingleStepDiffusionSampler
):
def
__init__
(
self
,
eta
=
1.0
,
s_noise
=
1.0
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
eta
=
eta
self
.
s_noise
=
s_noise
self
.
noise_sampler
=
lambda
x
:
torch
.
randn_like
(
x
)
def
ancestral_euler_step
(
self
,
x
,
denoised
,
sigma
,
sigma_down
):
d
=
to_d
(
x
,
sigma
,
denoised
)
dt
=
append_dims
(
sigma_down
-
sigma
,
x
.
ndim
)
return
self
.
euler_step
(
x
,
d
,
dt
)
def
ancestral_step
(
self
,
x
,
sigma
,
next_sigma
,
sigma_up
):
x
=
torch
.
where
(
append_dims
(
next_sigma
,
x
.
ndim
)
>
0.0
,
x
+
self
.
noise_sampler
(
x
)
*
self
.
s_noise
*
append_dims
(
sigma_up
,
x
.
ndim
),
x
,
)
return
x
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
):
x
,
s_in
,
sigmas
,
num_sigmas
,
cond
,
uc
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
x
=
self
.
sampler_step
(
s_in
*
sigmas
[
i
],
s_in
*
sigmas
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
,
)
return
x
class
LinearMultistepSampler
(
BaseDiffusionSampler
):
def
__init__
(
self
,
order
=
4
,
*
args
,
**
kwargs
,
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
order
=
order
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
,
**
kwargs
):
x
,
s_in
,
sigmas
,
num_sigmas
,
cond
,
uc
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
ds
=
[]
sigmas_cpu
=
sigmas
.
detach
().
cpu
().
numpy
()
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
sigma
=
s_in
*
sigmas
[
i
]
denoised
=
denoiser
(
*
self
.
guider
.
prepare_inputs
(
x
,
sigma
,
cond
,
uc
),
**
kwargs
)
denoised
=
self
.
guider
(
denoised
,
sigma
)
d
=
to_d
(
x
,
sigma
,
denoised
)
ds
.
append
(
d
)
if
len
(
ds
)
>
self
.
order
:
ds
.
pop
(
0
)
cur_order
=
min
(
i
+
1
,
self
.
order
)
coeffs
=
[
linear_multistep_coeff
(
cur_order
,
sigmas_cpu
,
i
,
j
)
for
j
in
range
(
cur_order
)
]
x
=
x
+
sum
(
coeff
*
d
for
coeff
,
d
in
zip
(
coeffs
,
reversed
(
ds
)))
return
x
class
EulerEDMSampler
(
EDMSampler
):
def
possible_correction_step
(
self
,
euler_step
,
x
,
d
,
dt
,
next_sigma
,
denoiser
,
cond
,
uc
):
return
euler_step
class
HeunEDMSampler
(
EDMSampler
):
def
possible_correction_step
(
self
,
euler_step
,
x
,
d
,
dt
,
next_sigma
,
denoiser
,
cond
,
uc
):
if
torch
.
sum
(
next_sigma
)
<
1e-14
:
# Save a network evaluation if all noise levels are 0
return
euler_step
else
:
denoised
=
self
.
denoise
(
euler_step
,
denoiser
,
next_sigma
,
cond
,
uc
)
d_new
=
to_d
(
euler_step
,
next_sigma
,
denoised
)
d_prime
=
(
d
+
d_new
)
/
2.0
# apply correction if noise level is not 0
x
=
torch
.
where
(
append_dims
(
next_sigma
,
x
.
ndim
)
>
0.0
,
x
+
d_prime
*
dt
,
euler_step
)
return
x
class
EulerAncestralSampler
(
AncestralSampler
):
def
sampler_step
(
self
,
sigma
,
next_sigma
,
denoiser
,
x
,
cond
,
uc
):
sigma_down
,
sigma_up
=
get_ancestral_step
(
sigma
,
next_sigma
,
eta
=
self
.
eta
)
denoised
=
self
.
denoise
(
x
,
denoiser
,
sigma
,
cond
,
uc
)
x
=
self
.
ancestral_euler_step
(
x
,
denoised
,
sigma
,
sigma_down
)
x
=
self
.
ancestral_step
(
x
,
sigma
,
next_sigma
,
sigma_up
)
return
x
class
DPMPP2SAncestralSampler
(
AncestralSampler
):
def
get_variables
(
self
,
sigma
,
sigma_down
):
t
,
t_next
=
(
to_neg_log_sigma
(
s
)
for
s
in
(
sigma
,
sigma_down
))
h
=
t_next
-
t
s
=
t
+
0.5
*
h
return
h
,
s
,
t
,
t_next
def
get_mult
(
self
,
h
,
s
,
t
,
t_next
):
mult1
=
to_sigma
(
s
)
/
to_sigma
(
t
)
mult2
=
(
-
0.5
*
h
).
expm1
()
mult3
=
to_sigma
(
t_next
)
/
to_sigma
(
t
)
mult4
=
(
-
h
).
expm1
()
return
mult1
,
mult2
,
mult3
,
mult4
def
sampler_step
(
self
,
sigma
,
next_sigma
,
denoiser
,
x
,
cond
,
uc
=
None
,
**
kwargs
):
sigma_down
,
sigma_up
=
get_ancestral_step
(
sigma
,
next_sigma
,
eta
=
self
.
eta
)
denoised
=
self
.
denoise
(
x
,
denoiser
,
sigma
,
cond
,
uc
)
x_euler
=
self
.
ancestral_euler_step
(
x
,
denoised
,
sigma
,
sigma_down
)
if
torch
.
sum
(
sigma_down
)
<
1e-14
:
# Save a network evaluation if all noise levels are 0
x
=
x_euler
else
:
h
,
s
,
t
,
t_next
=
self
.
get_variables
(
sigma
,
sigma_down
)
mult
=
[
append_dims
(
mult
,
x
.
ndim
)
for
mult
in
self
.
get_mult
(
h
,
s
,
t
,
t_next
)
]
x2
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised
denoised2
=
self
.
denoise
(
x2
,
denoiser
,
to_sigma
(
s
),
cond
,
uc
)
x_dpmpp2s
=
mult
[
2
]
*
x
-
mult
[
3
]
*
denoised2
# apply correction if noise level is not 0
x
=
torch
.
where
(
append_dims
(
sigma_down
,
x
.
ndim
)
>
0.0
,
x_dpmpp2s
,
x_euler
)
x
=
self
.
ancestral_step
(
x
,
sigma
,
next_sigma
,
sigma_up
)
return
x
class
DPMPP2MSampler
(
BaseDiffusionSampler
):
def
get_variables
(
self
,
sigma
,
next_sigma
,
previous_sigma
=
None
):
t
,
t_next
=
(
to_neg_log_sigma
(
s
)
for
s
in
(
sigma
,
next_sigma
))
h
=
t_next
-
t
if
previous_sigma
is
not
None
:
h_last
=
t
-
to_neg_log_sigma
(
previous_sigma
)
r
=
h_last
/
h
return
h
,
r
,
t
,
t_next
else
:
return
h
,
None
,
t
,
t_next
def
get_mult
(
self
,
h
,
r
,
t
,
t_next
,
previous_sigma
):
mult1
=
to_sigma
(
t_next
)
/
to_sigma
(
t
)
mult2
=
(
-
h
).
expm1
()
if
previous_sigma
is
not
None
:
mult3
=
1
+
1
/
(
2
*
r
)
mult4
=
1
/
(
2
*
r
)
return
mult1
,
mult2
,
mult3
,
mult4
else
:
return
mult1
,
mult2
def
sampler_step
(
self
,
old_denoised
,
previous_sigma
,
sigma
,
next_sigma
,
denoiser
,
x
,
cond
,
uc
=
None
,
):
denoised
=
self
.
denoise
(
x
,
denoiser
,
sigma
,
cond
,
uc
)
h
,
r
,
t
,
t_next
=
self
.
get_variables
(
sigma
,
next_sigma
,
previous_sigma
)
mult
=
[
append_dims
(
mult
,
x
.
ndim
)
for
mult
in
self
.
get_mult
(
h
,
r
,
t
,
t_next
,
previous_sigma
)
]
x_standard
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised
if
old_denoised
is
None
or
torch
.
sum
(
next_sigma
)
<
1e-14
:
# Save a network evaluation if all noise levels are 0 or on the first step
return
x_standard
,
denoised
else
:
denoised_d
=
mult
[
2
]
*
denoised
-
mult
[
3
]
*
old_denoised
x_advanced
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised_d
# apply correction if noise level is not 0 and not first step
x
=
torch
.
where
(
append_dims
(
next_sigma
,
x
.
ndim
)
>
0.0
,
x_advanced
,
x_standard
)
return
x
,
denoised
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
,
**
kwargs
):
x
,
s_in
,
sigmas
,
num_sigmas
,
cond
,
uc
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
old_denoised
=
None
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
x
,
old_denoised
=
self
.
sampler_step
(
old_denoised
,
None
if
i
==
0
else
s_in
*
sigmas
[
i
-
1
],
s_in
*
sigmas
[
i
],
s_in
*
sigmas
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
=
uc
,
)
return
x
class
SDEDPMPP2MSampler
(
BaseDiffusionSampler
):
def
get_variables
(
self
,
sigma
,
next_sigma
,
previous_sigma
=
None
):
t
,
t_next
=
(
to_neg_log_sigma
(
s
)
for
s
in
(
sigma
,
next_sigma
))
h
=
t_next
-
t
if
previous_sigma
is
not
None
:
h_last
=
t
-
to_neg_log_sigma
(
previous_sigma
)
r
=
h_last
/
h
return
h
,
r
,
t
,
t_next
else
:
return
h
,
None
,
t
,
t_next
def
get_mult
(
self
,
h
,
r
,
t
,
t_next
,
previous_sigma
):
mult1
=
to_sigma
(
t_next
)
/
to_sigma
(
t
)
*
(
-
h
).
exp
()
mult2
=
(
-
2
*
h
).
expm1
()
if
previous_sigma
is
not
None
:
mult3
=
1
+
1
/
(
2
*
r
)
mult4
=
1
/
(
2
*
r
)
return
mult1
,
mult2
,
mult3
,
mult4
else
:
return
mult1
,
mult2
def
sampler_step
(
self
,
old_denoised
,
previous_sigma
,
sigma
,
next_sigma
,
denoiser
,
x
,
cond
,
uc
=
None
,
):
denoised
=
self
.
denoise
(
x
,
denoiser
,
sigma
,
cond
,
uc
)
h
,
r
,
t
,
t_next
=
self
.
get_variables
(
sigma
,
next_sigma
,
previous_sigma
)
mult
=
[
append_dims
(
mult
,
x
.
ndim
)
for
mult
in
self
.
get_mult
(
h
,
r
,
t
,
t_next
,
previous_sigma
)
]
mult_noise
=
append_dims
(
next_sigma
*
(
1
-
(
-
2
*
h
).
exp
())
**
0.5
,
x
.
ndim
)
x_standard
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised
+
mult_noise
*
torch
.
randn_like
(
x
)
if
old_denoised
is
None
or
torch
.
sum
(
next_sigma
)
<
1e-14
:
# Save a network evaluation if all noise levels are 0 or on the first step
return
x_standard
,
denoised
else
:
denoised_d
=
mult
[
2
]
*
denoised
-
mult
[
3
]
*
old_denoised
x_advanced
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised_d
+
mult_noise
*
torch
.
randn_like
(
x
)
# apply correction if noise level is not 0 and not first step
x
=
torch
.
where
(
append_dims
(
next_sigma
,
x
.
ndim
)
>
0.0
,
x_advanced
,
x_standard
)
return
x
,
denoised
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
,
scale
=
None
,
**
kwargs
):
x
,
s_in
,
sigmas
,
num_sigmas
,
cond
,
uc
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
old_denoised
=
None
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
x
,
old_denoised
=
self
.
sampler_step
(
old_denoised
,
None
if
i
==
0
else
s_in
*
sigmas
[
i
-
1
],
s_in
*
sigmas
[
i
],
s_in
*
sigmas
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
=
uc
,
)
return
x
class
SdeditEDMSampler
(
EulerEDMSampler
):
def
__init__
(
self
,
edit_ratio
=
0.5
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
edit_ratio
=
edit_ratio
def
__call__
(
self
,
denoiser
,
image
,
randn
,
cond
,
uc
=
None
,
num_steps
=
None
,
edit_ratio
=
None
):
randn_unit
=
randn
.
clone
()
randn
,
s_in
,
sigmas
,
num_sigmas
,
cond
,
uc
=
self
.
prepare_sampling_loop
(
randn
,
cond
,
uc
,
num_steps
)
if
num_steps
is
None
:
num_steps
=
self
.
num_steps
if
edit_ratio
is
None
:
edit_ratio
=
self
.
edit_ratio
x
=
None
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
if
i
/
num_steps
<
edit_ratio
:
continue
if
x
is
None
:
x
=
image
+
randn_unit
*
append_dims
(
s_in
*
sigmas
[
i
],
len
(
randn_unit
.
shape
))
gamma
=
(
min
(
self
.
s_churn
/
(
num_sigmas
-
1
),
2
**
0.5
-
1
)
if
self
.
s_tmin
<=
sigmas
[
i
]
<=
self
.
s_tmax
else
0.0
)
x
=
self
.
sampler_step
(
s_in
*
sigmas
[
i
],
s_in
*
sigmas
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
,
gamma
,
)
return
x
class
VideoDDIMSampler
(
BaseDiffusionSampler
):
def
__init__
(
self
,
fixed_frames
=
0
,
sdedit
=
False
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
fixed_frames
=
fixed_frames
self
.
sdedit
=
sdedit
def
prepare_sampling_loop
(
self
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
):
alpha_cumprod_sqrt
,
timesteps
=
self
.
discretization
(
self
.
num_steps
if
num_steps
is
None
else
num_steps
,
device
=
self
.
device
,
return_idx
=
True
,
do_append_zero
=
False
,
)
alpha_cumprod_sqrt
=
torch
.
cat
(
[
alpha_cumprod_sqrt
,
alpha_cumprod_sqrt
.
new_ones
([
1
])])
timesteps
=
torch
.
cat
([
torch
.
tensor
(
list
(
timesteps
)).
new_zeros
([
1
])
-
1
,
torch
.
tensor
(
list
(
timesteps
))
])
uc
=
default
(
uc
,
cond
)
num_sigmas
=
len
(
alpha_cumprod_sqrt
)
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
return
x
,
s_in
,
alpha_cumprod_sqrt
,
num_sigmas
,
cond
,
uc
,
timesteps
def
denoise
(
self
,
x
,
denoiser
,
alpha_cumprod_sqrt
,
cond
,
uc
,
timestep
=
None
,
idx
=
None
,
scale
=
None
,
scale_emb
=
None
):
additional_model_inputs
=
{}
if
isinstance
(
scale
,
torch
.
Tensor
)
==
False
and
scale
==
1
:
additional_model_inputs
[
'idx'
]
=
x
.
new_ones
([
x
.
shape
[
0
]
])
*
timestep
if
scale_emb
is
not
None
:
additional_model_inputs
[
'scale_emb'
]
=
scale_emb
denoised
=
denoiser
(
x
,
alpha_cumprod_sqrt
,
cond
,
**
additional_model_inputs
).
to
(
torch
.
float32
)
else
:
additional_model_inputs
[
'idx'
]
=
torch
.
cat
(
[
x
.
new_ones
([
x
.
shape
[
0
]])
*
timestep
]
*
2
)
denoised
=
denoiser
(
*
self
.
guider
.
prepare_inputs
(
x
,
alpha_cumprod_sqrt
,
cond
,
uc
),
**
additional_model_inputs
).
to
(
torch
.
float32
)
# assert denoised.isnan().sum() < 1, f'detect nan at {timestep}'
if
isinstance
(
self
.
guider
,
DynamicCFG
):
denoised
=
self
.
guider
(
denoised
,
(
1
-
alpha_cumprod_sqrt
**
2
)
**
0.5
,
step_index
=
self
.
num_steps
-
timestep
,
scale
=
scale
)
# denoised = self.guider(
# denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=torch.tensor(self.num_steps - idx), scale=scale
# )
else
:
denoised
=
self
.
guider
(
denoised
,
(
1
-
alpha_cumprod_sqrt
**
2
)
**
0.5
,
scale
=
scale
)
return
denoised
def
sampler_step
(
self
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
denoiser
,
x
,
cond
,
uc
=
None
,
idx
=
None
,
timestep
=
None
,
scale
=
None
,
scale_emb
=
None
,
):
denoised
=
self
.
denoise
(
x
,
denoiser
,
alpha_cumprod_sqrt
,
cond
,
uc
,
timestep
,
idx
,
scale
=
scale
,
scale_emb
=
scale_emb
).
to
(
torch
.
float32
)
a_t
=
((
1
-
next_alpha_cumprod_sqrt
**
2
)
/
(
1
-
alpha_cumprod_sqrt
**
2
))
**
0.5
b_t
=
next_alpha_cumprod_sqrt
-
alpha_cumprod_sqrt
*
a_t
x
=
append_dims
(
a_t
,
x
.
ndim
)
*
x
+
append_dims
(
b_t
,
x
.
ndim
)
*
denoised
return
x
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
,
scale
=
None
,
scale_emb
=
None
):
x
,
s_in
,
alpha_cumprod_sqrt
,
num_sigmas
,
cond
,
uc
,
timesteps
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
x
=
self
.
sampler_step
(
s_in
*
alpha_cumprod_sqrt
[
i
],
s_in
*
alpha_cumprod_sqrt
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
,
idx
=
self
.
num_steps
-
i
,
timestep
=
timesteps
[
-
(
i
+
1
)],
scale
=
scale
,
scale_emb
=
scale_emb
,
)
return
x
class
VPSDEDPMPP2MSampler
(
VideoDDIMSampler
):
def
get_variables
(
self
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
previous_alpha_cumprod_sqrt
=
None
):
alpha_cumprod
=
alpha_cumprod_sqrt
**
2
lamb
=
((
alpha_cumprod
/
(
1
-
alpha_cumprod
))
**
0.5
).
log
()
next_alpha_cumprod
=
next_alpha_cumprod_sqrt
**
2
lamb_next
=
((
next_alpha_cumprod
/
(
1
-
next_alpha_cumprod
))
**
0.5
).
log
()
h
=
lamb_next
-
lamb
if
previous_alpha_cumprod_sqrt
is
not
None
:
previous_alpha_cumprod
=
previous_alpha_cumprod_sqrt
**
2
lamb_previous
=
((
previous_alpha_cumprod
/
(
1
-
previous_alpha_cumprod
))
**
0.5
).
log
()
h_last
=
lamb
-
lamb_previous
r
=
h_last
/
h
return
h
,
r
,
lamb
,
lamb_next
else
:
return
h
,
None
,
lamb
,
lamb_next
def
get_mult
(
self
,
h
,
r
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
previous_alpha_cumprod_sqrt
):
mult1
=
((
1
-
next_alpha_cumprod_sqrt
**
2
)
/
(
1
-
alpha_cumprod_sqrt
**
2
))
**
0.5
*
(
-
h
).
exp
()
mult2
=
(
-
2
*
h
).
expm1
()
*
next_alpha_cumprod_sqrt
if
previous_alpha_cumprod_sqrt
is
not
None
:
mult3
=
1
+
1
/
(
2
*
r
)
mult4
=
1
/
(
2
*
r
)
return
mult1
,
mult2
,
mult3
,
mult4
else
:
return
mult1
,
mult2
def
sampler_step
(
self
,
old_denoised
,
previous_alpha_cumprod_sqrt
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
denoiser
,
x
,
cond
,
uc
=
None
,
idx
=
None
,
timestep
=
None
,
scale
=
None
,
scale_emb
=
None
,
):
denoised
=
self
.
denoise
(
x
,
denoiser
,
alpha_cumprod_sqrt
,
cond
,
uc
,
timestep
,
idx
,
scale
=
scale
,
scale_emb
=
scale_emb
).
to
(
torch
.
float32
)
if
idx
==
1
:
return
denoised
,
denoised
# assert denoised.isnan().sum() < 1, f' nan is detected at {idx}'
h
,
r
,
lamb
,
lamb_next
=
self
.
get_variables
(
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
previous_alpha_cumprod_sqrt
)
mult
=
[
append_dims
(
mult
,
x
.
ndim
)
for
mult
in
self
.
get_mult
(
h
,
r
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
previous_alpha_cumprod_sqrt
)
]
mult_noise
=
append_dims
(
(
1
-
next_alpha_cumprod_sqrt
**
2
)
**
0.5
*
(
1
-
(
-
2
*
h
).
exp
())
**
0.5
,
x
.
ndim
)
x_standard
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised
+
mult_noise
*
torch
.
randn_like
(
x
)
if
old_denoised
is
None
or
torch
.
sum
(
next_alpha_cumprod_sqrt
)
<
1e-14
:
# Save a network evaluation if all noise levels are 0 or on the first step
return
x_standard
,
denoised
else
:
            denoised_d = mult[2] * denoised - mult[3] * old_denoised
            x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)

            x = x_advanced

        return x, denoised

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
        x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
            x, cond, uc, num_steps)

        if self.fixed_frames > 0:
            prefix_frames = x[:, :self.fixed_frames]
        old_denoised = None
        for i in self.get_sigma_gen(num_sigmas):
            if self.fixed_frames > 0:
                if self.sdedit:
                    rd = torch.randn_like(prefix_frames)
                    noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(
                        s_in * (1 - alpha_cumprod_sqrt[i]**2)**0.5, len(prefix_frames.shape))
                    x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1)
                else:
                    x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
            x, old_denoised = self.sampler_step(
                old_denoised,
                None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
                s_in * alpha_cumprod_sqrt[i],
                s_in * alpha_cumprod_sqrt[i + 1],
                denoiser,
                x,
                cond,
                uc=uc,
                idx=self.num_steps - i,
                timestep=timesteps[-(i + 1)],
                scale=scale,
                scale_emb=scale_emb,
            )

        if self.fixed_frames > 0:
            x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)

        return x


class CascadeVPSDEDPMPP2MSampler(VPSDEDPMPP2MSampler):

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
        x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
            x, cond, uc, num_steps)
        ref_x = self.ref_x
        ref_noise_step = self.ref_noise_step
        timesteps = timesteps[timesteps <= ref_noise_step]
        alpha_cumprod_sqrt = alpha_cumprod_sqrt[-len(timesteps):]
        num_sigmas = len(timesteps)
        old_denoised = None
        x = ref_x
        for i in self.get_sigma_gen(num_sigmas):
            x, old_denoised = self.sampler_step(
                old_denoised,
                None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
                s_in * alpha_cumprod_sqrt[i],
                s_in * alpha_cumprod_sqrt[i + 1],
                denoiser,
                x,
                cond,
                uc=uc,
                idx=num_sigmas - 1 - i,
                timestep=timesteps[-(i + 1)],
                scale=scale,
                scale_emb=scale_emb,
            )

        return x


class VPODEDPMPP2MSampler(VideoDDIMSampler):

    def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
        alpha_cumprod = alpha_cumprod_sqrt**2
        lamb = ((alpha_cumprod / (1 - alpha_cumprod))**0.5).log()
        next_alpha_cumprod = next_alpha_cumprod_sqrt**2
        lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod))**0.5).log()
        h = lamb_next - lamb

        if previous_alpha_cumprod_sqrt is not None:
            previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
            lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod))**0.5).log()
            h_last = lamb - lamb_previous
            r = h_last / h
            return h, r, lamb, lamb_next
        else:
            return h, None, lamb, lamb_next

    def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
        mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2))**0.5
        mult2 = (-h).expm1() * next_alpha_cumprod_sqrt

        if previous_alpha_cumprod_sqrt is not None:
            mult3 = 1 + 1 / (2 * r)
            mult4 = 1 / (2 * r)
            return mult1, mult2, mult3, mult4
        else:
            return mult1, mult2

    def sampler_step(
        self,
        old_denoised,
        previous_alpha_cumprod_sqrt,
        alpha_cumprod_sqrt,
        next_alpha_cumprod_sqrt,
        denoiser,
        x,
        cond,
        uc=None,
        idx=None,
        timestep=None,
    ):
        denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32)
        if idx == 1:
            return denoised, denoised

        h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt,
                                                   previous_alpha_cumprod_sqrt)
        mult = [
            append_dims(mult, x.ndim)
            for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt,
                                      previous_alpha_cumprod_sqrt)
        ]

        x_standard = mult[0] * x - mult[1] * denoised
        if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
            # Save a network evaluation if all noise levels are 0 or on the first step
            return x_standard, denoised
        else:
            denoised_d = mult[2] * denoised - mult[3] * old_denoised
            x_advanced = mult[0] * x - mult[1] * denoised_d

            x = x_advanced

        return x, denoised

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
        x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
            x, cond, uc, num_steps)

        old_denoised = None
        for i in self.get_sigma_gen(num_sigmas):
            x, old_denoised = self.sampler_step(
                old_denoised,
                None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
                s_in * alpha_cumprod_sqrt[i],
                s_in * alpha_cumprod_sqrt[i + 1],
                denoiser,
                x,
                cond,
                uc=uc,
                idx=self.num_steps - i,
                timestep=timesteps[-(i + 1)],
            )

        return x
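
Note on the update above: VPODEDPMPP2MSampler.get_variables works in log-SNR space (lamb is the log of alpha/sigma, with alpha given by alpha_cumprod_sqrt), and get_mult turns the step size h = lamb_next - lamb into the DPM-Solver++(2M)-style coefficients that the sampler applies as mult[0..3]. The following standalone sketch is not part of the commit; the helper name and example values are illustrative only, and it simply recomputes those coefficients for a single step so they can be inspected in isolation:

import torch

def ode_step_mults(a_t, a_next, a_prev=None):
    # a_* stand in for s_in * alpha_cumprod_sqrt[...] at the previous / current / next step
    lamb = ((a_t**2 / (1 - a_t**2))**0.5).log()
    lamb_next = ((a_next**2 / (1 - a_next**2))**0.5).log()
    h = lamb_next - lamb
    mult1 = ((1 - a_next**2) / (1 - a_t**2))**0.5   # scales the current sample x
    mult2 = (-h).expm1() * a_next                   # scales the denoised prediction
    if a_prev is None:
        return mult1, mult2
    lamb_prev = ((a_prev**2 / (1 - a_prev**2))**0.5).log()
    r = (lamb - lamb_prev) / h
    return mult1, mult2, 1 + 1 / (2 * r), 1 / (2 * r)

print(ode_step_mults(torch.tensor(0.80), torch.tensor(0.95), torch.tensor(0.60)))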
flashvideo/sgm/modules/diffusionmodules/sampling_utils.py
0 → 100644
View file @
3b804999
import torch
from einops import rearrange
from scipy import integrate

from ...util import append_dims


class NoDynamicThresholding:

    def __call__(self, uncond, cond, scale):
        scale = append_dims(scale, cond.ndim) if isinstance(scale, torch.Tensor) else scale
        return uncond + scale * (cond - uncond)


class StaticThresholding:

    def __call__(self, uncond, cond, scale):
        result = uncond + scale * (cond - uncond)
        result = torch.clamp(result, min=-1.0, max=1.0)
        return result


def dynamic_threshold(x, p=0.95):
    N, T, C, H, W = x.shape
    x = rearrange(x, 'n t c h w -> n c (t h w)')
    l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True)
    s = torch.maximum(-l, r)
    threshold_mask = (s > 1).expand(-1, -1, H * W * T)
    if threshold_mask.any():
        x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x)
    x = rearrange(x, 'n c (t h w) -> n t c h w', t=T, h=H, w=W)
    return x


def dynamic_thresholding2(x0):
    p = 0.995  # A hyperparameter in the paper of "Imagen" [1].
    origin_dtype = x0.dtype
    x0 = x0.to(torch.float32)
    s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
    s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim())
    x0 = torch.clamp(x0, -s, s)  # / s
    return x0.to(origin_dtype)


def latent_dynamic_thresholding(x0):
    p = 0.9995
    origin_dtype = x0.dtype
    x0 = x0.to(torch.float32)
    s = torch.quantile(torch.abs(x0), p, dim=2)
    s = append_dims(s, x0.dim())
    x0 = torch.clamp(x0, -s, s) / s
    return x0.to(origin_dtype)


def dynamic_thresholding3(x0):
    p = 0.995  # A hyperparameter in the paper of "Imagen" [1].
    origin_dtype = x0.dtype
    x0 = x0.to(torch.float32)
    s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
    s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim())
    x0 = torch.clamp(x0, -s, s)  # / s
    return x0.to(origin_dtype)


class DynamicThresholding:

    def __call__(self, uncond, cond, scale):
        mean = uncond.mean()
        std = uncond.std()
        result = uncond + scale * (cond - uncond)
        result_mean, result_std = result.mean(), result.std()
        result = (result - result_mean) / result_std * std
        # result = dynamic_thresholding3(result)
        return result


class DynamicThresholdingV1:

    def __init__(self, scale_factor):
        self.scale_factor = scale_factor

    def __call__(self, uncond, cond, scale):
        result = uncond + scale * (cond - uncond)
        unscaled_result = result / self.scale_factor
        B, T, C, H, W = unscaled_result.shape
        flattened = rearrange(unscaled_result, 'b t c h w -> b c (t h w)')
        means = flattened.mean(dim=2).unsqueeze(2)
        recentered = flattened - means
        magnitudes = recentered.abs().max()
        normalized = recentered / magnitudes
        thresholded = latent_dynamic_thresholding(normalized)
        denormalized = thresholded * magnitudes
        uncentered = denormalized + means
        unflattened = rearrange(uncentered, 'b c (t h w) -> b t c h w', t=T, h=H, w=W)
        scaled_result = unflattened * self.scale_factor
        return scaled_result


class DynamicThresholdingV2:

    def __call__(self, uncond, cond, scale):
        B, T, C, H, W = uncond.shape
        diff = cond - uncond
        mim_target = uncond + diff * 4.0
        cfg_target = uncond + diff * 8.0

        mim_flattened = rearrange(mim_target, 'b t c h w -> b c (t h w)')
        cfg_flattened = rearrange(cfg_target, 'b t c h w -> b c (t h w)')
        mim_means = mim_flattened.mean(dim=2).unsqueeze(2)
        cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2)
        mim_centered = mim_flattened - mim_means
        cfg_centered = cfg_flattened - cfg_means

        mim_scaleref = mim_centered.std(dim=2).unsqueeze(2)
        cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2)

        cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref

        result = cfg_renormalized + cfg_means
        unflattened = rearrange(result, 'b c (t h w) -> b t c h w', t=T, h=H, w=W)
        return unflattened


def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    def fn(tau):
        prod = 1.0
        for k in range(order):
            if j == k:
                continue
            prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return prod

    return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]


def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    if not eta:
        return sigma_to, 0.0
    sigma_up = torch.minimum(
        sigma_to,
        eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5,
    )
    sigma_down = (sigma_to**2 - sigma_up**2)**0.5
    return sigma_down, sigma_up


def to_d(x, sigma, denoised):
    return (x - denoised) / append_dims(sigma, x.ndim)


def to_neg_log_sigma(sigma):
    return sigma.log().neg()


def to_sigma(neg_log_sigma):
    return neg_log_sigma.neg().exp()
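
The thresholding classes above all share the signature (uncond, cond, scale) and return the guided prediction; they are presumably the combiners instantiated by guiders.py elsewhere in this commit. A hedged usage sketch, assuming the module has been imported and using dummy [batch, frames, channels, height, width] latents with illustrative values:

import torch

uncond = torch.randn(1, 4, 16, 8, 8)
cond = uncond + 0.1 * torch.randn_like(uncond)

plain = NoDynamicThresholding()(uncond, cond, scale=6.0)   # plain classifier-free guidance combine
rescaled = DynamicThresholding()(uncond, cond, scale=6.0)  # zero-centred and rescaled to uncond's std
print(plain.std().item(), rescaled.std().item())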
flashvideo/sgm/modules/diffusionmodules/sigma_sampling.py
0 → 100644
View file @
3b804999
import torch
import torch.distributed
from sat import mpu

from ...util import default, instantiate_from_config


class EDMSampling:

    def __init__(self, p_mean=-1.2, p_std=1.2):
        self.p_mean = p_mean
        self.p_std = p_std

    def __call__(self, n_samples, rand=None):
        log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples, )))
        return log_sigma.exp()


class DiscreteSampling:

    def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False):
        self.num_idx = num_idx
        self.sigmas = instantiate_from_config(discretization_config)(
            num_idx, do_append_zero=do_append_zero, flip=flip)
        world_size = mpu.get_data_parallel_world_size()
        self.uniform_sampling = uniform_sampling
        if self.uniform_sampling:
            i = 1
            while True:
                if world_size % i != 0 or num_idx % (world_size // i) != 0:
                    i += 1
                else:
                    self.group_num = world_size // i
                    break

            assert self.group_num > 0
            assert world_size % self.group_num == 0
            self.group_width = world_size // self.group_num  # the number of rank in one group
            self.sigma_interval = self.num_idx // self.group_num

    def idx_to_sigma(self, idx):
        return self.sigmas[idx]

    def __call__(self, n_samples, rand=None, return_idx=False):
        if self.uniform_sampling:
            rank = mpu.get_data_parallel_rank()
            group_index = rank // self.group_width
            idx = default(
                rand,
                torch.randint(group_index * self.sigma_interval,
                              (group_index + 1) * self.sigma_interval, (n_samples, )),
            )
        else:
            idx = default(
                rand,
                torch.randint(0, self.num_idx, (n_samples, )),
            )
        if return_idx:
            return self.idx_to_sigma(idx), idx
        else:
            return self.idx_to_sigma(idx)


class PartialDiscreteSampling:

    def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
        self.total_num_idx = total_num_idx
        self.partial_num_idx = partial_num_idx
        self.sigmas = instantiate_from_config(discretization_config)(
            total_num_idx, do_append_zero=do_append_zero, flip=flip)

    def idx_to_sigma(self, idx):
        return self.sigmas[idx]

    def __call__(self, n_samples, rand=None):
        idx = default(
            rand,
            # torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)),
            torch.randint(0, self.partial_num_idx, (n_samples, )),
        )
        return self.idx_to_sigma(idx)
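
When uniform_sampling is enabled, DiscreteSampling above partitions the num_idx diffusion timesteps into disjoint intervals across data-parallel ranks, so different ranks draw their training noise levels from different slices. A standalone illustration of that partition logic (no sat/mpu dependency; the helper name below is ours, not the repository's):

def group_layout(world_size, num_idx):
    # mirrors the grouping loop in DiscreteSampling.__init__
    i = 1
    while world_size % i != 0 or num_idx % (world_size // i) != 0:
        i += 1
    group_num = world_size // i
    group_width = world_size // group_num      # ranks per group
    sigma_interval = num_idx // group_num      # timestep indices per group
    return {
        rank: (rank // group_width * sigma_interval, (rank // group_width + 1) * sigma_interval)
        for rank in range(world_size)
    }

print(group_layout(world_size=4, num_idx=1000))
# {0: (0, 250), 1: (250, 500), 2: (500, 750), 3: (750, 1000)}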
flashvideo/sgm/modules/diffusionmodules/util.py
0 → 100644
View file @
3b804999
"""
adopted from
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
and
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
and
https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
thanks!
"""
import
math
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
einops
import
rearrange
,
repeat
def
make_beta_schedule
(
schedule
,
n_timestep
,
linear_start
=
1e-4
,
linear_end
=
2e-2
,
):
if
schedule
==
'linear'
:
betas
=
torch
.
linspace
(
linear_start
**
0.5
,
linear_end
**
0.5
,
n_timestep
,
dtype
=
torch
.
float64
)
**
2
return
betas
.
numpy
()
def
extract_into_tensor
(
a
,
t
,
x_shape
):
b
,
*
_
=
t
.
shape
out
=
a
.
gather
(
-
1
,
t
)
return
out
.
reshape
(
b
,
*
((
1
,
)
*
(
len
(
x_shape
)
-
1
)))
def
mixed_checkpoint
(
func
,
inputs
:
dict
,
params
,
flag
):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
it also works with non-tensor inputs
:param func: the function to evaluate.
:param inputs: the argument dictionary to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if
flag
:
tensor_keys
=
[
key
for
key
in
inputs
if
isinstance
(
inputs
[
key
],
torch
.
Tensor
)
]
tensor_inputs
=
[
inputs
[
key
]
for
key
in
inputs
if
isinstance
(
inputs
[
key
],
torch
.
Tensor
)
]
non_tensor_keys
=
[
key
for
key
in
inputs
if
not
isinstance
(
inputs
[
key
],
torch
.
Tensor
)
]
non_tensor_inputs
=
[
inputs
[
key
]
for
key
in
inputs
if
not
isinstance
(
inputs
[
key
],
torch
.
Tensor
)
]
args
=
tuple
(
tensor_inputs
)
+
tuple
(
non_tensor_inputs
)
+
tuple
(
params
)
return
MixedCheckpointFunction
.
apply
(
func
,
len
(
tensor_inputs
),
len
(
non_tensor_inputs
),
tensor_keys
,
non_tensor_keys
,
*
args
,
)
else
:
return
func
(
**
inputs
)
class
MixedCheckpointFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
run_function
,
length_tensors
,
length_non_tensors
,
tensor_keys
,
non_tensor_keys
,
*
args
,
):
ctx
.
end_tensors
=
length_tensors
ctx
.
end_non_tensors
=
length_tensors
+
length_non_tensors
ctx
.
gpu_autocast_kwargs
=
{
'enabled'
:
torch
.
is_autocast_enabled
(),
'dtype'
:
torch
.
get_autocast_gpu_dtype
(),
'cache_enabled'
:
torch
.
is_autocast_cache_enabled
(),
}
assert
len
(
tensor_keys
)
==
length_tensors
and
len
(
non_tensor_keys
)
==
length_non_tensors
ctx
.
input_tensors
=
{
key
:
val
for
(
key
,
val
)
in
zip
(
tensor_keys
,
list
(
args
[:
ctx
.
end_tensors
]))
}
ctx
.
input_non_tensors
=
{
key
:
val
for
(
key
,
val
)
in
zip
(
non_tensor_keys
,
list
(
args
[
ctx
.
end_tensors
:
ctx
.
end_non_tensors
]))
}
ctx
.
run_function
=
run_function
ctx
.
input_params
=
list
(
args
[
ctx
.
end_non_tensors
:])
with
torch
.
no_grad
():
output_tensors
=
ctx
.
run_function
(
**
ctx
.
input_tensors
,
**
ctx
.
input_non_tensors
)
return
output_tensors
@
staticmethod
def
backward
(
ctx
,
*
output_grads
):
# additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
ctx
.
input_tensors
=
{
key
:
ctx
.
input_tensors
[
key
].
detach
().
requires_grad_
(
True
)
for
key
in
ctx
.
input_tensors
}
with
torch
.
enable_grad
(),
torch
.
cuda
.
amp
.
autocast
(
**
ctx
.
gpu_autocast_kwargs
):
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies
=
{
key
:
ctx
.
input_tensors
[
key
].
view_as
(
ctx
.
input_tensors
[
key
])
for
key
in
ctx
.
input_tensors
}
# shallow_copies.update(additional_args)
output_tensors
=
ctx
.
run_function
(
**
shallow_copies
,
**
ctx
.
input_non_tensors
)
input_grads
=
torch
.
autograd
.
grad
(
output_tensors
,
list
(
ctx
.
input_tensors
.
values
())
+
ctx
.
input_params
,
output_grads
,
allow_unused
=
True
,
)
del
ctx
.
input_tensors
del
ctx
.
input_params
del
output_tensors
return
((
None
,
None
,
None
,
None
,
None
)
+
input_grads
[:
ctx
.
end_tensors
]
+
(
None
,
)
*
(
ctx
.
end_non_tensors
-
ctx
.
end_tensors
)
+
input_grads
[
ctx
.
end_tensors
:])
def
checkpoint
(
func
,
inputs
,
params
,
flag
):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if
flag
:
args
=
tuple
(
inputs
)
+
tuple
(
params
)
return
CheckpointFunction
.
apply
(
func
,
len
(
inputs
),
*
args
)
else
:
return
func
(
*
inputs
)
class
CheckpointFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
run_function
,
length
,
*
args
):
ctx
.
run_function
=
run_function
ctx
.
input_tensors
=
list
(
args
[:
length
])
ctx
.
input_params
=
list
(
args
[
length
:])
ctx
.
gpu_autocast_kwargs
=
{
'enabled'
:
torch
.
is_autocast_enabled
(),
'dtype'
:
torch
.
get_autocast_gpu_dtype
(),
'cache_enabled'
:
torch
.
is_autocast_cache_enabled
(),
}
with
torch
.
no_grad
():
output_tensors
=
ctx
.
run_function
(
*
ctx
.
input_tensors
)
return
output_tensors
@
staticmethod
def
backward
(
ctx
,
*
output_grads
):
ctx
.
input_tensors
=
[
x
.
detach
().
requires_grad_
(
True
)
for
x
in
ctx
.
input_tensors
]
with
torch
.
enable_grad
(),
torch
.
cuda
.
amp
.
autocast
(
**
ctx
.
gpu_autocast_kwargs
):
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies
=
[
x
.
view_as
(
x
)
for
x
in
ctx
.
input_tensors
]
output_tensors
=
ctx
.
run_function
(
*
shallow_copies
)
input_grads
=
torch
.
autograd
.
grad
(
output_tensors
,
ctx
.
input_tensors
+
ctx
.
input_params
,
output_grads
,
allow_unused
=
True
,
)
del
ctx
.
input_tensors
del
ctx
.
input_params
del
output_tensors
return
(
None
,
None
)
+
input_grads
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
,
repeat_only
=
False
,
dtype
=
torch
.
float32
):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if
not
repeat_only
:
half
=
dim
//
2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
device
=
timesteps
.
device
)
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
if
dim
%
2
:
embedding
=
torch
.
cat
(
[
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
else
:
embedding
=
repeat
(
timesteps
,
'b -> b d'
,
d
=
dim
)
return
embedding
.
to
(
dtype
)
def
zero_module
(
module
):
"""
Zero out the parameters of a module and return it.
"""
for
p
in
module
.
parameters
():
p
.
detach
().
zero_
()
return
module
def
scale_module
(
module
,
scale
):
"""
Scale the parameters of a module and return it.
"""
for
p
in
module
.
parameters
():
p
.
detach
().
mul_
(
scale
)
return
module
def
mean_flat
(
tensor
):
"""
Take the mean over all non-batch dimensions.
"""
return
tensor
.
mean
(
dim
=
list
(
range
(
1
,
len
(
tensor
.
shape
))))
def
normalization
(
channels
):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return
GroupNorm32
(
32
,
channels
)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class
SiLU
(
nn
.
Module
):
def
forward
(
self
,
x
):
return
x
*
torch
.
sigmoid
(
x
)
class
GroupNorm32
(
nn
.
GroupNorm
):
def
forward
(
self
,
x
):
return
super
().
forward
(
x
).
type
(
x
.
dtype
)
def
conv_nd
(
dims
,
*
args
,
**
kwargs
):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if
dims
==
1
:
return
nn
.
Conv1d
(
*
args
,
**
kwargs
)
elif
dims
==
2
:
return
nn
.
Conv2d
(
*
args
,
**
kwargs
)
elif
dims
==
3
:
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
'unsupported dimensions:
{
dims
}
'
)
def
linear
(
*
args
,
**
kwargs
):
"""
Create a linear module.
"""
return
nn
.
Linear
(
*
args
,
**
kwargs
)
def
avg_pool_nd
(
dims
,
*
args
,
**
kwargs
):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if
dims
==
1
:
return
nn
.
AvgPool1d
(
*
args
,
**
kwargs
)
elif
dims
==
2
:
return
nn
.
AvgPool2d
(
*
args
,
**
kwargs
)
elif
dims
==
3
:
return
nn
.
AvgPool3d
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
'unsupported dimensions:
{
dims
}
'
)
class
AlphaBlender
(
nn
.
Module
):
strategies
=
[
'learned'
,
'fixed'
,
'learned_with_images'
]
def
__init__
(
self
,
alpha
:
float
,
merge_strategy
:
str
=
'learned_with_images'
,
rearrange_pattern
:
str
=
'b t -> (b t) 1 1'
,
):
super
().
__init__
()
self
.
merge_strategy
=
merge_strategy
self
.
rearrange_pattern
=
rearrange_pattern
assert
merge_strategy
in
self
.
strategies
,
f
'merge_strategy needs to be in
{
self
.
strategies
}
'
if
self
.
merge_strategy
==
'fixed'
:
self
.
register_buffer
(
'mix_factor'
,
torch
.
Tensor
([
alpha
]))
elif
self
.
merge_strategy
==
'learned'
or
self
.
merge_strategy
==
'learned_with_images'
:
self
.
register_parameter
(
'mix_factor'
,
torch
.
nn
.
Parameter
(
torch
.
Tensor
([
alpha
])))
else
:
raise
ValueError
(
f
'unknown merge strategy
{
self
.
merge_strategy
}
'
)
def
get_alpha
(
self
,
image_only_indicator
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
merge_strategy
==
'fixed'
:
alpha
=
self
.
mix_factor
elif
self
.
merge_strategy
==
'learned'
:
alpha
=
torch
.
sigmoid
(
self
.
mix_factor
)
elif
self
.
merge_strategy
==
'learned_with_images'
:
assert
image_only_indicator
is
not
None
,
'need image_only_indicator ...'
alpha
=
torch
.
where
(
image_only_indicator
.
bool
(),
torch
.
ones
(
1
,
1
,
device
=
image_only_indicator
.
device
),
rearrange
(
torch
.
sigmoid
(
self
.
mix_factor
),
'... -> ... 1'
),
)
alpha
=
rearrange
(
alpha
,
self
.
rearrange_pattern
)
else
:
raise
NotImplementedError
return
alpha
def
forward
(
self
,
x_spatial
:
torch
.
Tensor
,
x_temporal
:
torch
.
Tensor
,
image_only_indicator
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
alpha
=
self
.
get_alpha
(
image_only_indicator
)
x
=
alpha
.
to
(
x_spatial
.
dtype
)
*
x_spatial
+
(
1.0
-
alpha
).
to
(
x_spatial
.
dtype
)
*
x_temporal
return
x
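
timestep_embedding above produces the standard sinusoidal embedding that diffusion models feed into their time-conditioning MLP. A minimal usage sketch, assuming the module has been imported (shapes only; values are illustrative):

import torch

t = torch.tensor([0, 250, 999])          # one timestep index per batch element
emb = timestep_embedding(t, dim=320)     # cosine half followed by sine half
print(emb.shape, emb.dtype)              # torch.Size([3, 320]) torch.float32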
flashvideo/sgm/modules/diffusionmodules/wrappers.py
0 → 100644
View file @
3b804999
import torch
import torch.nn as nn
from packaging import version

OPENAIUNETWRAPPER = 'sgm.modules.diffusionmodules.wrappers.OpenAIWrapper'


class IdentityWrapper(nn.Module):

    def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32):
        super().__init__()
        compile = (
            torch.compile if (version.parse(torch.__version__) >= version.parse('2.0.0')) and compile_model
            else lambda x: x)
        self.diffusion_model = compile(diffusion_model)
        self.dtype = dtype

    def forward(self, *args, **kwargs):
        return self.diffusion_model(*args, **kwargs)


class OpenAIWrapper(IdentityWrapper):

    def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor:
        for key in c:
            c[key] = c[key].to(self.dtype)

        if x.dim() == 4:
            x = torch.cat((x, c.get('concat', torch.Tensor([]).type_as(x))), dim=1)
        elif x.dim() == 5:
            x = torch.cat((x, c.get('concat', torch.Tensor([]).type_as(x))), dim=2)
        else:
            raise ValueError('Input tensor must be 4D or 5D')

        return self.diffusion_model(
            x,
            timesteps=t,
            context=c.get('crossattn', None),
            y=c.get('vector', None),
            **kwargs,
        )
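
OpenAIWrapper above casts every conditioning tensor to the wrapper dtype and concatenates c['concat'] onto the input along the channel axis (dim=1 for 4D inputs, dim=2 for 5D video latents) before calling the wrapped network. A hedged sketch with a stand-in network (the real diffusion model is defined elsewhere in this commit; the shapes below are assumptions for illustration):

import torch
import torch.nn as nn

class DummyNet(nn.Module):
    def forward(self, x, timesteps=None, context=None, y=None):
        return x  # echo the already-concatenated input

wrapper = OpenAIWrapper(DummyNet())
x = torch.randn(2, 13, 16, 8, 8)                   # assumed [batch, frames, channels, h, w]
c = {'concat': torch.randn(2, 13, 16, 8, 8),
     'crossattn': torch.randn(2, 226, 4096)}
print(wrapper(x, t=torch.zeros(2), c=c).shape)     # torch.Size([2, 13, 32, 8, 8])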
flashvideo/sgm/modules/distributions/__init__.py
0 → 100644
View file @
3b804999
flashvideo/sgm/modules/distributions/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
3b804999
File added
flashvideo/sgm/modules/distributions/__pycache__/distributions.cpython-310.pyc
0 → 100644
View file @
3b804999
File added