ModelZoo / STAR / Commits / 1f5da520

Commit 1f5da520 authored Dec 05, 2025 by yangzhong

git init

Pipeline #3144 failed with stages in 0 seconds
Changes: 326 · Pipelines: 1

Showing 20 changed files with 4840 additions and 0 deletions (+4840, -0)
cogvideox-based/sat/sgm/modules/diffusionmodules/lora.py  +362  -0
cogvideox-based/sat/sgm/modules/diffusionmodules/loss.py  +279  -0
cogvideox-based/sat/sgm/modules/diffusionmodules/model.py  +683  -0
cogvideox-based/sat/sgm/modules/diffusionmodules/openaimodel.py  +1249  -0
cogvideox-based/sat/sgm/modules/diffusionmodules/sampling.py  +773  -0
cogvideox-based/sat/sgm/modules/diffusionmodules/sampling_utils.py  +155  -0
cogvideox-based/sat/sgm/modules/diffusionmodules/sigma_sampling.py  +80  -0
cogvideox-based/sat/sgm/modules/diffusionmodules/util.py  +328  -0
cogvideox-based/sat/sgm/modules/diffusionmodules/wrappers.py  +41  -0
cogvideox-based/sat/sgm/modules/distributions/__init__.py  +0  -0
cogvideox-based/sat/sgm/modules/distributions/__pycache__/__init__.cpython-39.pyc  +0  -0
cogvideox-based/sat/sgm/modules/distributions/__pycache__/distributions.cpython-39.pyc  +0  -0
cogvideox-based/sat/sgm/modules/distributions/distributions.py  +94  -0
cogvideox-based/sat/sgm/modules/ema.py  +82  -0
cogvideox-based/sat/sgm/modules/encoders/__init__.py  +0  -0
cogvideox-based/sat/sgm/modules/encoders/__pycache__/__init__.cpython-39.pyc  +0  -0
cogvideox-based/sat/sgm/modules/encoders/__pycache__/modules.cpython-39.pyc  +0  -0
cogvideox-based/sat/sgm/modules/encoders/modules.py  +281  -0
cogvideox-based/sat/sgm/modules/fuse_sft_block.py  +140  -0
cogvideox-based/sat/sgm/modules/video_attention.py  +293  -0
cogvideox-based/sat/sgm/modules/diffusionmodules/lora.py (new file, 0 → 100644)
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch.nn.functional as F
from torch import nn


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        super().__init__()

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank
        self.out_features = out_features
        self.in_features = in_features

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class LoRAConv2dLayer(nn.Module):
    def __init__(
        self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
    ):
        super().__init__()

        self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
        # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
        self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)

        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class LoRACompatibleConv(nn.Conv2d):
    """
    A convolutional layer that can be used with LoRA.
    """

    def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer
        self.scale = scale

    def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
        self.lora_layer = lora_layer

    def _fuse_lora(self, lora_scale=1.0):
        if self.lora_layer is None:
            return

        dtype, device = self.weight.data.dtype, self.weight.data.device

        w_orig = self.weight.data.float()
        w_up = self.lora_layer.up.weight.data.float()
        w_down = self.lora_layer.down.weight.data.float()

        if self.lora_layer.network_alpha is not None:
            w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank

        fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
        fusion = fusion.reshape(w_orig.shape)
        fused_weight = w_orig + (lora_scale * fusion)
        self.weight.data = fused_weight.to(device=device, dtype=dtype)

        # we can drop the lora layer now
        self.lora_layer = None

        # offload the up and down matrices to CPU to not blow the memory
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        self._lora_scale = lora_scale

    def _unfuse_lora(self):
        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
            return

        fused_weight = self.weight.data
        dtype, device = fused_weight.data.dtype, fused_weight.data.device

        self.w_up = self.w_up.to(device=device).float()
        self.w_down = self.w_down.to(device).float()

        fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
        fusion = fusion.reshape(fused_weight.shape)
        unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
        self.weight.data = unfused_weight.to(device=device, dtype=dtype)

        self.w_up = None
        self.w_down = None

    def forward(self, hidden_states, scale: float = None):
        if scale is None:
            scale = self.scale
        if self.lora_layer is None:
            # make sure to use the functional Conv2d function as otherwise torch.compile's graph will break
            # see: https://github.com/huggingface/diffusers/pull/4315
            return F.conv2d(
                hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
            )
        else:
            return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))


class LoRACompatibleLinear(nn.Linear):
    """
    A Linear layer that can be used with LoRA.
    """

    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer
        self.scale = scale

    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        self.lora_layer = lora_layer

    def _fuse_lora(self, lora_scale=1.0):
        if self.lora_layer is None:
            return

        dtype, device = self.weight.data.dtype, self.weight.data.device

        w_orig = self.weight.data.float()
        w_up = self.lora_layer.up.weight.data.float()
        w_down = self.lora_layer.down.weight.data.float()

        if self.lora_layer.network_alpha is not None:
            w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank

        fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
        self.weight.data = fused_weight.to(device=device, dtype=dtype)

        # we can drop the lora layer now
        self.lora_layer = None

        # offload the up and down matrices to CPU to not blow the memory
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        self._lora_scale = lora_scale

    def _unfuse_lora(self):
        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
            return

        fused_weight = self.weight.data
        dtype, device = fused_weight.dtype, fused_weight.device

        w_up = self.w_up.to(device=device).float()
        w_down = self.w_down.to(device).float()

        unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
        self.weight.data = unfused_weight.to(device=device, dtype=dtype)

        self.w_up = None
        self.w_down = None

    def forward(self, hidden_states, scale: float = None):
        if scale is None:
            scale = self.scale
        if self.lora_layer is None:
            out = super().forward(hidden_states)
            return out
        else:
            out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
            return out


def _find_children(
    model,
    search_class: List[Type[nn.Module]] = [nn.Linear],
):
    """
    Find all modules of a certain class (or union of classes).

    Returns all matching modules, along with the parent of those modules and the
    names they are referenced by.
    """
    # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
    for parent in model.modules():
        for name, module in parent.named_children():
            if any([isinstance(module, _class) for _class in search_class]):
                yield parent, name, module


def _find_modules_v2(
    model,
    ancestor_class: Optional[Set[str]] = None,
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [
        LoRACompatibleLinear,
        LoRACompatibleConv,
        LoRALinearLayer,
        LoRAConv2dLayer,
    ],
):
    """
    Find all modules of a certain class (or union of classes) that are direct or
    indirect descendants of other modules of a certain class (or union of classes).

    Returns all matching modules, along with the parent of those modules and the
    names they are referenced by.
    """
    # Get the targets we should replace all linears under
    if ancestor_class is not None:
        ancestors = (module for module in model.modules() if module.__class__.__name__ in ancestor_class)
    else:
        # this, in case you want to naively iterate over all modules.
        ancestors = [module for module in model.modules()]

    # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if any([isinstance(module, _class) for _class in search_class]):
                # Find the direct parent if this is a descendant, not a child, of target
                *path, name = fullname.split(".")
                parent = ancestor
                flag = False
                while path:
                    try:
                        parent = parent.get_submodule(path.pop(0))
                    except:
                        flag = True
                        break
                if flag:
                    continue

                # Skip this linear if it's a child of a LoraInjectedLinear
                if exclude_children_of and any([isinstance(parent, _class) for _class in exclude_children_of]):
                    continue
                # Otherwise, yield it
                yield parent, name, module


_find_modules = _find_modules_v2


def inject_trainable_lora_extended(
    model: nn.Module,
    target_replace_module: Set[str] = None,
    rank: int = 4,
    scale: float = 1.0,
):
    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
    ):
        if _child_module.__class__ == nn.Linear:
            weight = _child_module.weight
            bias = _child_module.bias
            lora_layer = LoRALinearLayer(
                in_features=_child_module.in_features,
                out_features=_child_module.out_features,
                rank=rank,
            )
            _tmp = (
                LoRACompatibleLinear(
                    _child_module.in_features,
                    _child_module.out_features,
                    lora_layer=lora_layer,
                    scale=scale,
                )
                .to(weight.dtype)
                .to(weight.device)
            )
            _tmp.weight = weight
            if bias is not None:
                _tmp.bias = bias
        elif _child_module.__class__ == nn.Conv2d:
            weight = _child_module.weight
            bias = _child_module.bias
            lora_layer = LoRAConv2dLayer(
                in_features=_child_module.in_channels,
                out_features=_child_module.out_channels,
                rank=rank,
                kernel_size=_child_module.kernel_size,
                stride=_child_module.stride,
                padding=_child_module.padding,
            )
            _tmp = (
                LoRACompatibleConv(
                    _child_module.in_channels,
                    _child_module.out_channels,
                    kernel_size=_child_module.kernel_size,
                    stride=_child_module.stride,
                    padding=_child_module.padding,
                    lora_layer=lora_layer,
                    scale=scale,
                )
                .to(weight.dtype)
                .to(weight.device)
            )
            _tmp.weight = weight
            if bias is not None:
                _tmp.bias = bias
        else:
            continue

        _module._modules[name] = _tmp
        # print('injecting lora layer to', _module, name)

    return


def update_lora_scale(
    model: nn.Module,
    target_module: Set[str] = None,
    scale: float = 1.0,
):
    for _module, name, _child_module in _find_modules(
        model, target_module, search_class=[LoRACompatibleLinear, LoRACompatibleConv]
    ):
        _child_module.scale = scale

    return
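As a quick orientation for readers of this diff, here is a minimal usage sketch (not part of the commit) of the helpers defined in lora.py above. The TinyNet module and its layer names are hypothetical; the sketch only illustrates how inject_trainable_lora_extended swaps nn.Linear / nn.Conv2d children for LoRACompatibleLinear / LoRACompatibleConv with zero-initialised up-projections, and how update_lora_scale adjusts the adapter contribution afterwards.

    # Sketch, assuming the classes and functions from lora.py above are importable.
    import torch
    from torch import nn

    class TinyNet(nn.Module):  # hypothetical toy model, not from this repository
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 32)
            self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

        def forward(self, x_vec, x_img):
            return self.proj(x_vec), self.conv(x_img)

    model = TinyNet()
    # Replace every Linear/Conv2d child with its LoRA-compatible counterpart (rank-4 adapters).
    inject_trainable_lora_extended(model, target_replace_module=None, rank=4, scale=1.0)
    # Because the up-projection starts at zero, outputs are unchanged until the adapters train.
    update_lora_scale(model, scale=0.5)  # later, rescale the LoRA branch contribution
    print(type(model.proj).__name__)  # LoRACompatibleLinear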
cogvideox-based/sat/sgm/modules/diffusionmodules/loss.py (new file, 0 → 100644)
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import ListConfig
import math

from ...modules.diffusionmodules.sampling import VideoDDIMSampler, VPSDEDPMPP2MSampler
from ...util import append_dims, instantiate_from_config
from ...modules.autoencoding.lpips.loss.lpips import LPIPS

# import rearrange
from einops import rearrange
import random

from sat import mpu


class StandardDiffusionLoss(nn.Module):
    def __init__(
        self,
        sigma_sampler_config,
        type="df",
        offset_noise_level=0.0,
        batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
    ):
        super().__init__()

        assert type in ["l2", "l1", "lpips", "df"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)

        self.type = type
        self.offset_noise_level = offset_noise_level

        if type == "lpips":
            self.lpips = LPIPS().eval()

        if not batch2model_keys:
            batch2model_keys = []

        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]

        self.batch2model_keys = set(batch2model_keys)

    def __call__(self, network, denoiser, conditioner, input, batch):
        cond = conditioner(batch)
        additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}

        sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            noise = (
                noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
            )
            noise = noise.to(input.dtype)
        noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
        model_output = denoiser(network, noised_input, sigmas, cond, **additional_model_inputs)
        w = append_dims(denoiser.w(sigmas), input.ndim)
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
        if self.type == "l2":
            return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)
        elif self.type == "l1":
            return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
        elif self.type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
            return loss


class VideoDiffusionLoss(StandardDiffusionLoss):
    def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs):
        self.fixed_frames = fixed_frames
        self.block_scale = block_scale
        self.block_size = block_size
        self.min_snr_value = min_snr_value
        super().__init__(**kwargs)

    def __call__(self, network, denoiser, conditioner, input, batch):
        cond = conditioner(batch)
        additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}

        alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
        alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
        idx = idx.to(input.device)

        noise = torch.randn_like(input)

        # broadcast noise
        mp_size = mpu.get_model_parallel_world_size()
        global_rank = torch.distributed.get_rank() // mp_size
        src = global_rank * mp_size
        torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group())
        torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group())
        torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group())

        additional_model_inputs["idx"] = idx

        if self.offset_noise_level > 0.0:
            noise = (
                noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
            )

        noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims(
            (1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim
        )

        model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
        w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim)  # v-pred

        if self.min_snr_value is not None:
            w = min(w, self.min_snr_value)
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
        if self.type == "l2":
            return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)
        elif self.type == "l1":
            return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
        elif self.type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
            return loss


def fourier_transform(x, balance=None):
    """
    Apply Fourier transform to the input tensor and separate it into low-frequency and high-frequency components.

    Args:
        x (torch.Tensor): Input tensor of shape [batch_size, channels, height, width]
        balance (torch.Tensor or float, optional): Learnable balance parameter for adjusting the cutoff frequency.

    Returns:
        low_freq (torch.Tensor): Low-frequency components (with real and imaginary parts)
        high_freq (torch.Tensor): High-frequency components (with real and imaginary parts)
    """
    # Perform 2D Real Fourier transform (rfft2 only computes positive frequencies)
    x = x.to(torch.float32)
    fft_x = torch.fft.rfft2(x, dim=(-2, -1))

    # Calculate magnitude of frequency components
    magnitude = torch.abs(fft_x)

    # Set cutoff based on balance or default to the 80th percentile of the magnitude for low frequency
    if balance is None:
        # Downsample the magnitude to reduce computation for large tensors
        subsample_size = 10000  # Adjust based on available memory and tensor size
        if magnitude.numel() > subsample_size:
            # Randomly select a subset of values to approximate the quantile
            magnitude_sample = magnitude.flatten()[torch.randint(0, magnitude.numel(), (subsample_size,))]
            cutoff = torch.quantile(magnitude_sample, 0.8)  # 80th percentile for low frequency
        else:
            cutoff = torch.quantile(magnitude, 0.8)  # 80th percentile for low frequency
    else:
        # balance is clamped for safety and used to scale the mean-based cutoff
        cutoff = magnitude.mean() * (1 + 10 * balance)

    # Smooth mask using sigmoid to ensure gradients can pass through
    sharpness = 10  # A parameter to control the sharpness of the transition
    low_freq_mask = torch.sigmoid(sharpness * (cutoff - magnitude))

    # High-frequency mask can be derived from low-frequency mask (1 - low_freq_mask)
    high_freq_mask = 1 - low_freq_mask

    # Separate low and high frequencies using smooth masks
    low_freq = fft_x * low_freq_mask
    high_freq = fft_x * high_freq_mask

    # Return real and imaginary parts separately
    low_freq = torch.stack([low_freq.real, low_freq.imag], dim=-1)
    high_freq = torch.stack([high_freq.real, high_freq.imag], dim=-1)

    return low_freq, high_freq


def extract_frequencies(video: torch.Tensor, balance=None):
    """
    Extract high-frequency and low-frequency components of a video using Fourier transform.

    Args:
        video (torch.Tensor): Input video tensor of shape [batch_size, channels, frames, height, width]

    Returns:
        low_freq (torch.Tensor): Low-frequency components of the video
        high_freq (torch.Tensor): High-frequency components of the video
    """
    # batch_size, channels, frames, _, _ = video.shape
    video = rearrange(video, "b c t h w -> (b t) c h w")  # Reshape for Fourier transform

    # Apply Fourier transform to each frame
    low_freq, high_freq = fourier_transform(video, balance=balance)

    return low_freq, high_freq


class SRDiffusionLoss(StandardDiffusionLoss):
    def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs):
        self.fixed_frames = fixed_frames
        self.block_scale = block_scale
        self.block_size = block_size
        self.min_snr_value = min_snr_value
        super().__init__(**kwargs)

    def __call__(self, network, denoiser, conditioner, input, batch, hq_video=None, decode_first_stage=None):
        cond = conditioner(batch)
        additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}

        alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
        alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
        idx = idx.to(input.device)

        noise = torch.randn_like(input)

        # broadcast noise
        mp_size = mpu.get_model_parallel_world_size()
        global_rank = torch.distributed.get_rank() // mp_size
        src = global_rank * mp_size
        torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group())
        torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group())
        torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group())

        additional_model_inputs["idx"] = idx

        if self.offset_noise_level > 0.0:
            noise = (
                noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
            )

        noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims(
            (1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim
        )

        # Uncomment for SR training
        noised_input = torch.cat((noised_input, batch["lq"]), dim=2)  # [B, T/4, 32, 60, 90]

        model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
        w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim)  # v-pred

        if self.min_snr_value is not None:
            w = min(w, self.min_snr_value)

        if self.type == "df":
            # print('idx:', idx)
            return self.get_loss(model_output, input, w, hq_video, idx, decode_first_stage)
        else:
            return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w, video_data=None, timesteps=None, decode_first_stage=None):
        # model_output: x_hat_0; target: x_0
        if self.type == "l2":
            return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)
        elif self.type == "l1":
            return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
        elif self.type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
            return loss
        elif self.type == "df":
            # v-prediction loss
            loss_v = torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)

            with torch.no_grad():
                model_output = model_output.to(torch.bfloat16)
                model_output = model_output.permute(0, 2, 1, 3, 4).contiguous()
                pred_x0 = decode_first_stage(model_output)
            # print('pred_x0:', pred_x0.shape)  # [1, 3, 25, 480, 720]
            # print('video_data:', video_data.shape)  # [1, 3, 25, 480, 720]

            low_freq_pred_x0, high_freq_pred_x0 = extract_frequencies(pred_x0)
            low_freq_x0, high_freq_x0 = extract_frequencies(video_data)

            # timestep-aware loss
            alpha = 2
            ct = (timesteps / 999) ** alpha
            loss_low = F.l1_loss(low_freq_pred_x0.float(), low_freq_x0.float(), reduction="mean")
            loss_high = F.l1_loss(high_freq_pred_x0.float(), high_freq_x0.float(), reduction="mean")
            loss_t = 0.01 * (ct * loss_low + (1 - ct) * loss_high)

            beta = 1  # 1 is the default setting
            weight_t = 1 - timesteps / 999
            loss = loss_v + beta * weight_t * loss_t
            # print('loss_v:', loss_v.mean().item(), 'loss_t:', (beta * weight_t * loss_t).mean().item())
            return loss
\ No newline at end of file
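A small numeric sketch (independent of the commit's models) of the timestep-aware weighting used in SRDiffusionLoss.get_loss for type "df": with alpha = 2, c_t = (t / 999)^2 puts the weight on the low-frequency L1 term near t = 999 and on the high-frequency term near t = 0, and the whole frequency loss is further damped by weight_t = 1 - t / 999. The loss_low / loss_high values below are hypothetical placeholders.

    import torch

    # Hypothetical per-batch L1 values standing in for loss_low / loss_high.
    loss_low, loss_high = torch.tensor(0.30), torch.tensor(0.10)

    for t in [999, 500, 100]:
        timesteps = torch.tensor(float(t))
        ct = (timesteps / 999) ** 2                     # alpha = 2, as in SRDiffusionLoss.get_loss
        loss_t = 0.01 * (ct * loss_low + (1 - ct) * loss_high)
        weight_t = 1 - timesteps / 999                  # beta = 1 leaves this as the final multiplier
        print(t, round((weight_t * loss_t).item(), 6))  # contribution added on top of loss_v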
cogvideox-based/sat/sgm/modules/diffusionmodules/model.py (new file, 0 → 100644)
# pytorch_diffusion + derived encoder decoder
import math
from typing import Any, Callable, Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from packaging import version

try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILABLE = True
except:
    XFORMERS_IS_AVAILABLE = False
    print("no module 'xformers'. Processing without...")

from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention


def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def attention(self, h_: torch.Tensor) -> torch.Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
        h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default
        # compute attention

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x, **kwargs):
        h_ = x
        h_ = self.attention(h_)
        h_ = self.proj_out(h_)
        return x + h_


class MemoryEfficientAttnBlock(nn.Module):
    """
    Uses xformers efficient implementation,
    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    Note: this is a single-head self-attention operation
    """

    #
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.attention_op: Optional[Any] = None

    def attention(self, h_: torch.Tensor) -> torch.Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))

        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(B, t.shape[1], 1, C)
            .permute(0, 2, 1, 3)
            .reshape(B * 1, t.shape[1], C)
            .contiguous(),
            (q, k, v),
        )
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)
        return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)

    def forward(self, x, **kwargs):
        h_ = x
        h_ = self.attention(h_)
        h_ = self.proj_out(h_)
        return x + h_


class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
    def forward(self, x, context=None, mask=None, **unused_kwargs):
        b, c, h, w = x.shape
        x = rearrange(x, "b c h w -> b (h w) c")
        out = super().forward(x, context=context, mask=mask)
        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
        return x + out


def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
        "memory-efficient-cross-attn",
        "linear",
        "none",
    ], f"attn_type {attn_type} unknown"
    if version.parse(torch.__version__) < version.parse("2.0.0") and attn_type != "none":
        assert XFORMERS_IS_AVAILABLE, (
            f"We do not support vanilla attention in {torch.__version__} anymore, "
            f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
        )
        attn_type = "vanilla-xformers"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        assert attn_kwargs is None
        return AttnBlock(in_channels)
    elif attn_type == "vanilla-xformers":
        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
        return MemoryEfficientAttnBlock(in_channels)
    elif attn_type == "memory-efficient-cross-attn":
        attn_kwargs["query_dim"] = in_channels
        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        return LinAttnBlock(in_channels)


class Model(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
        use_linear_attn=False,
        attn_type="vanilla",
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList(
                [
                    torch.nn.Linear(self.ch, self.temb_ch),
                    torch.nn.Linear(self.temb_ch, self.temb_ch),
                ]
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t=None, context=None):
        # assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        return self.conv_out.weight


class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        make_attn_cls = self._make_attn()
        make_resblock_cls = self._make_resblock()
        make_conv_cls = self._make_conv()

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = make_resblock_cls(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
        self.mid.block_2 = make_resblock_cls(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    make_resblock_cls(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn_cls(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = make_conv_cls(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def _make_attn(self) -> Callable:
        return make_attn

    def _make_resblock(self) -> Callable:
        return ResnetBlock

    def _make_conv(self) -> Callable:
        return torch.nn.Conv2d

    def get_last_layer(self, **kwargs):
        return self.conv_out.weight

    def forward(self, z, **kwargs):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, **kwargs)
        h = self.mid.attn_1(h, **kwargs)
        h = self.mid.block_2(h, temb, **kwargs)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb, **kwargs)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, **kwargs)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h, **kwargs)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
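A minimal sketch (assuming PyTorch >= 2.0 and illustrative hyperparameters, not the configurations shipped in this repository) of running the Encoder and Decoder above as an autoencoder round-trip:

    import torch

    # Illustrative values only; the real settings live in the repo's model configs.
    enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=[],
                  in_channels=3, resolution=64, z_channels=4, double_z=True)
    dec = Decoder(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=[],
                  in_channels=3, resolution=64, z_channels=4)

    x = torch.randn(1, 3, 64, 64)
    moments = enc(x)          # [1, 2 * z_channels, 32, 32]: mean/logvar halves when double_z=True
    z = moments[:, :4]        # take the mean half as a stand-in for sampling the posterior
    print(dec(z).shape)       # torch.Size([1, 3, 64, 64])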
cogvideox-based/sat/sgm/modules/diffusionmodules/openaimodel.py (new file, 0 → 100644)
import
os
import
math
from
abc
import
abstractmethod
from
functools
import
partial
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
...modules.attention
import
SpatialTransformer
from
...modules.diffusionmodules.util
import
(
avg_pool_nd
,
checkpoint
,
conv_nd
,
linear
,
normalization
,
timestep_embedding
,
zero_module
,
)
from
...modules.diffusionmodules.lora
import
inject_trainable_lora_extended
,
update_lora_scale
from
...modules.video_attention
import
SpatialVideoTransformer
from
...util
import
default
,
exists
# dummy replace
def
convert_module_to_f16
(
x
):
pass
def
convert_module_to_f32
(
x
):
pass
## go
class
AttentionPool2d
(
nn
.
Module
):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def
__init__
(
self
,
spacial_dim
:
int
,
embed_dim
:
int
,
num_heads_channels
:
int
,
output_dim
:
int
=
None
,
):
super
().
__init__
()
self
.
positional_embedding
=
nn
.
Parameter
(
th
.
randn
(
embed_dim
,
spacial_dim
**
2
+
1
)
/
embed_dim
**
0.5
)
self
.
qkv_proj
=
conv_nd
(
1
,
embed_dim
,
3
*
embed_dim
,
1
)
self
.
c_proj
=
conv_nd
(
1
,
embed_dim
,
output_dim
or
embed_dim
,
1
)
self
.
num_heads
=
embed_dim
//
num_heads_channels
self
.
attention
=
QKVAttention
(
self
.
num_heads
)
def
forward
(
self
,
x
):
b
,
c
,
*
_spatial
=
x
.
shape
x
=
x
.
reshape
(
b
,
c
,
-
1
)
# NC(HW)
x
=
th
.
cat
([
x
.
mean
(
dim
=-
1
,
keepdim
=
True
),
x
],
dim
=-
1
)
# NC(HW+1)
x
=
x
+
self
.
positional_embedding
[
None
,
:,
:].
to
(
x
.
dtype
)
# NC(HW+1)
x
=
self
.
qkv_proj
(
x
)
x
=
self
.
attention
(
x
)
x
=
self
.
c_proj
(
x
)
return
x
[:,
:,
0
]
class
TimestepBlock
(
nn
.
Module
):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@
abstractmethod
def
forward
(
self
,
x
,
emb
):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class
TimestepEmbedSequential
(
nn
.
Sequential
,
TimestepBlock
):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def
forward
(
self
,
x
:
th
.
Tensor
,
emb
:
th
.
Tensor
,
context
:
Optional
[
th
.
Tensor
]
=
None
,
image_only_indicator
:
Optional
[
th
.
Tensor
]
=
None
,
time_context
:
Optional
[
int
]
=
None
,
num_video_frames
:
Optional
[
int
]
=
None
,
):
from
...modules.diffusionmodules.video_model
import
VideoResBlock
for
layer
in
self
:
module
=
layer
if
isinstance
(
module
,
TimestepBlock
)
and
not
isinstance
(
module
,
VideoResBlock
):
x
=
layer
(
x
,
emb
)
elif
isinstance
(
module
,
VideoResBlock
):
x
=
layer
(
x
,
emb
,
num_video_frames
,
image_only_indicator
)
elif
isinstance
(
module
,
SpatialVideoTransformer
):
x
=
layer
(
x
,
context
,
time_context
,
num_video_frames
,
image_only_indicator
,
)
elif
isinstance
(
module
,
SpatialTransformer
):
x
=
layer
(
x
,
context
)
else
:
x
=
layer
(
x
)
return
x
class
Upsample
(
nn
.
Module
):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
,
third_up
=
False
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
dims
=
dims
self
.
third_up
=
third_up
if
use_conv
:
self
.
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
padding
)
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
if
self
.
dims
==
3
:
t_factor
=
1
if
not
self
.
third_up
else
2
x
=
F
.
interpolate
(
x
,
(
t_factor
*
x
.
shape
[
2
],
x
.
shape
[
3
]
*
2
,
x
.
shape
[
4
]
*
2
),
mode
=
"nearest"
,
)
else
:
x
=
F
.
interpolate
(
x
,
scale_factor
=
2
,
mode
=
"nearest"
)
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
return
x
class
TransposedUpsample
(
nn
.
Module
):
"Learned 2x upsampling without padding"
def
__init__
(
self
,
channels
,
out_channels
=
None
,
ks
=
5
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
up
=
nn
.
ConvTranspose2d
(
self
.
channels
,
self
.
out_channels
,
kernel_size
=
ks
,
stride
=
2
)
def
forward
(
self
,
x
):
return
self
.
up
(
x
)
class
Downsample
(
nn
.
Module
):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
,
third_down
=
False
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
dims
=
dims
stride
=
2
if
dims
!=
3
else
((
1
,
2
,
2
)
if
not
third_down
else
(
2
,
2
,
2
))
if
use_conv
:
print
(
f
"Building a Downsample layer with
{
dims
}
dims."
)
print
(
f
" --> settings are:
\n
in-chn:
{
self
.
channels
}
, out-chn:
{
self
.
out_channels
}
, "
f
"kernel-size: 3, stride:
{
stride
}
, padding:
{
padding
}
"
)
if
dims
==
3
:
print
(
f
" --> Downsampling third axis (time):
{
third_down
}
"
)
self
.
op
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
padding
,
)
else
:
assert
self
.
channels
==
self
.
out_channels
self
.
op
=
avg_pool_nd
(
dims
,
kernel_size
=
stride
,
stride
=
stride
)
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
return
self
.
op
(
x
)
class
ResBlock
(
TimestepBlock
):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def
__init__
(
self
,
channels
,
emb_channels
,
dropout
,
out_channels
=
None
,
use_conv
=
False
,
use_scale_shift_norm
=
False
,
dims
=
2
,
use_checkpoint
=
False
,
up
=
False
,
down
=
False
,
kernel_size
=
3
,
exchange_temb_dims
=
False
,
skip_t_emb
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
self
.
emb_channels
=
emb_channels
self
.
dropout
=
dropout
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
use_checkpoint
=
use_checkpoint
self
.
use_scale_shift_norm
=
use_scale_shift_norm
self
.
exchange_temb_dims
=
exchange_temb_dims
if
isinstance
(
kernel_size
,
Iterable
):
padding
=
[
k
//
2
for
k
in
kernel_size
]
else
:
padding
=
kernel_size
//
2
self
.
in_layers
=
nn
.
Sequential
(
normalization
(
channels
),
nn
.
SiLU
(),
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
kernel_size
,
padding
=
padding
),
)
self
.
updown
=
up
or
down
if
up
:
self
.
h_upd
=
Upsample
(
channels
,
False
,
dims
)
self
.
x_upd
=
Upsample
(
channels
,
False
,
dims
)
elif
down
:
self
.
h_upd
=
Downsample
(
channels
,
False
,
dims
)
self
.
x_upd
=
Downsample
(
channels
,
False
,
dims
)
else
:
self
.
h_upd
=
self
.
x_upd
=
nn
.
Identity
()
self
.
skip_t_emb
=
skip_t_emb
self
.
emb_out_channels
=
2
*
self
.
out_channels
if
use_scale_shift_norm
else
self
.
out_channels
if
self
.
skip_t_emb
:
print
(
f
"Skipping timestep embedding in
{
self
.
__class__
.
__name__
}
"
)
assert
not
self
.
use_scale_shift_norm
self
.
emb_layers
=
None
self
.
exchange_temb_dims
=
False
else
:
self
.
emb_layers
=
nn
.
Sequential
(
nn
.
SiLU
(),
linear
(
emb_channels
,
self
.
emb_out_channels
,
),
)
self
.
out_layers
=
nn
.
Sequential
(
normalization
(
self
.
out_channels
),
nn
.
SiLU
(),
nn
.
Dropout
(
p
=
dropout
),
zero_module
(
conv_nd
(
dims
,
self
.
out_channels
,
self
.
out_channels
,
kernel_size
,
padding
=
padding
,
)
),
)
if
self
.
out_channels
==
channels
:
self
.
skip_connection
=
nn
.
Identity
()
elif
use_conv
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
kernel_size
,
padding
=
padding
)
else
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
def
forward
(
self
,
x
,
emb
):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return
checkpoint
(
self
.
_forward
,
(
x
,
emb
),
self
.
parameters
(),
self
.
use_checkpoint
)
def
_forward
(
self
,
x
,
emb
):
if
self
.
updown
:
in_rest
,
in_conv
=
self
.
in_layers
[:
-
1
],
self
.
in_layers
[
-
1
]
h
=
in_rest
(
x
)
h
=
self
.
h_upd
(
h
)
x
=
self
.
x_upd
(
x
)
h
=
in_conv
(
h
)
else
:
h
=
self
.
in_layers
(
x
)
if
self
.
skip_t_emb
:
emb_out
=
th
.
zeros_like
(
h
)
else
:
emb_out
=
self
.
emb_layers
(
emb
).
type
(
h
.
dtype
)
while
len
(
emb_out
.
shape
)
<
len
(
h
.
shape
):
emb_out
=
emb_out
[...,
None
]
if
self
.
use_scale_shift_norm
:
out_norm
,
out_rest
=
self
.
out_layers
[
0
],
self
.
out_layers
[
1
:]
scale
,
shift
=
th
.
chunk
(
emb_out
,
2
,
dim
=
1
)
h
=
out_norm
(
h
)
*
(
1
+
scale
)
+
shift
h
=
out_rest
(
h
)
else
:
if
self
.
exchange_temb_dims
:
emb_out
=
rearrange
(
emb_out
,
"b t c ... -> b c t ..."
)
h
=
h
+
emb_out
h
=
self
.
out_layers
(
h
)
return
self
.
skip_connection
(
x
)
+
h
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x, **kwargs):
        # TODO add crossframe attention and use mixed checkpoint
        return checkpoint(
            self._forward, (x,), self.parameters(), True
        )  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        # return pt_checkpoint(self._forward, x)  # pytorch

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])
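For context, here is a minimal, hedged sketch of the hook-up the docstring above describes, applied to the attention modules defined just below. It assumes the third-party `thop` package is installed and uses an arbitrary input shape; it is an illustration, not code from this repository.

    from thop import profile  # assumption: `thop` is available in the environment

    attn = AttentionBlock(channels=64, num_heads=4)
    feats = th.randn(2, 64, 16, 16)  # [N, C, *spatial], flattened internally
    macs, params = profile(
        attn,
        inputs=(feats,),
        custom_ops={
            QKVAttention: QKVAttention.count_flops,
            QKVAttentionLegacy: QKVAttentionLegacy.count_flops,
        },
    )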
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
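A quick shape check (illustration only, not repository code) contrasting the two layouts the docstrings describe: the legacy module expects heads grouped first ([N, H * 3C, T]), the new one expects q/k/v grouped first ([N, 3 * HC, T]); both return [N, H * C, T].

    N, H, C, T = 2, 4, 16, 64
    qkv = th.randn(N, 3 * H * C, T)  # same total width, interpreted differently by each module

    legacy = QKVAttentionLegacy(n_heads=H)
    new = QKVAttention(n_heads=H)

    print(legacy(qkv).shape)  # torch.Size([2, 64, 64]) -> [N, H * C, T]
    print(new(qkv).shape)     # torch.Size([2, 64, 64]) -> [N, H * C, T]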
class Timestep(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        return timestep_embedding(t, self.dim)


str_to_dtype = {"fp32": th.float32, "fp16": th.float16, "bf16": th.bfloat16}
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_head_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def
__init__
(
self
,
in_channels
,
model_channels
,
out_channels
,
num_res_blocks
,
attention_resolutions
,
dropout
=
0
,
channel_mult
=
(
1
,
2
,
4
,
8
),
conv_resample
=
True
,
dims
=
2
,
num_classes
=
None
,
use_checkpoint
=
False
,
use_fp16
=
False
,
num_heads
=-
1
,
num_head_channels
=-
1
,
num_heads_upsample
=-
1
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
use_new_attention_order
=
False
,
use_spatial_transformer
=
False
,
# custom transformer support
transformer_depth
=
1
,
# custom transformer support
context_dim
=
None
,
# custom transformer support
n_embed
=
None
,
# custom support for prediction of discrete ids into codebook of first stage vq model
legacy
=
True
,
disable_self_attentions
=
None
,
num_attention_blocks
=
None
,
disable_middle_self_attn
=
False
,
use_linear_in_transformer
=
False
,
spatial_transformer_attn_type
=
"softmax"
,
adm_in_channels
=
None
,
use_fairscale_checkpoint
=
False
,
offload_to_cpu
=
False
,
transformer_depth_middle
=
None
,
dtype
=
"fp32"
,
lora_init
=
False
,
lora_rank
=
4
,
lora_scale
=
1.0
,
lora_weight_path
=
None
,
):
super
().
__init__
()
from
omegaconf.listconfig
import
ListConfig
self
.
dtype
=
str_to_dtype
[
dtype
]
if
use_spatial_transformer
:
assert
(
context_dim
is
not
None
),
"Fool!! You forgot to include the dimension of your cross-attention conditioning..."
if
context_dim
is
not
None
:
assert
(
use_spatial_transformer
),
"Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
if
type
(
context_dim
)
==
ListConfig
:
context_dim
=
list
(
context_dim
)
if
num_heads_upsample
==
-
1
:
num_heads_upsample
=
num_heads
if
num_heads
==
-
1
:
assert
num_head_channels
!=
-
1
,
"Either num_heads or num_head_channels has to be set"
if
num_head_channels
==
-
1
:
assert
num_heads
!=
-
1
,
"Either num_heads or num_head_channels has to be set"
self
.
in_channels
=
in_channels
self
.
model_channels
=
model_channels
self
.
out_channels
=
out_channels
if
isinstance
(
transformer_depth
,
int
):
transformer_depth
=
len
(
channel_mult
)
*
[
transformer_depth
]
elif
isinstance
(
transformer_depth
,
ListConfig
):
transformer_depth
=
list
(
transformer_depth
)
transformer_depth_middle
=
default
(
transformer_depth_middle
,
transformer_depth
[
-
1
])
if
isinstance
(
num_res_blocks
,
int
):
self
.
num_res_blocks
=
len
(
channel_mult
)
*
[
num_res_blocks
]
else
:
if
len
(
num_res_blocks
)
!=
len
(
channel_mult
):
raise
ValueError
(
"provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult"
)
self
.
num_res_blocks
=
num_res_blocks
# self.num_res_blocks = num_res_blocks
if
disable_self_attentions
is
not
None
:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert
len
(
disable_self_attentions
)
==
len
(
channel_mult
)
if
num_attention_blocks
is
not
None
:
assert
len
(
num_attention_blocks
)
==
len
(
self
.
num_res_blocks
)
assert
all
(
map
(
lambda
i
:
self
.
num_res_blocks
[
i
]
>=
num_attention_blocks
[
i
],
range
(
len
(
num_attention_blocks
)),
)
)
print
(
f
"Constructor of UNetModel received num_attention_blocks=
{
num_attention_blocks
}
. "
f
"This option has LESS priority than attention_resolutions
{
attention_resolutions
}
, "
f
"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f
"attention will still not be set."
)
# todo: convert to warning
self
.
attention_resolutions
=
attention_resolutions
self
.
dropout
=
dropout
self
.
channel_mult
=
channel_mult
self
.
conv_resample
=
conv_resample
self
.
num_classes
=
num_classes
self
.
use_checkpoint
=
use_checkpoint
if
use_fp16
:
print
(
"WARNING: use_fp16 was dropped and has no effect anymore."
)
# self.dtype = th.float16 if use_fp16 else th.float32
self
.
num_heads
=
num_heads
self
.
num_head_channels
=
num_head_channels
self
.
num_heads_upsample
=
num_heads_upsample
self
.
predict_codebook_ids
=
n_embed
is
not
None
assert
use_fairscale_checkpoint
!=
use_checkpoint
or
not
(
use_checkpoint
or
use_fairscale_checkpoint
)
self
.
use_fairscale_checkpoint
=
False
checkpoint_wrapper_fn
=
(
partial
(
checkpoint_wrapper
,
offload_to_cpu
=
offload_to_cpu
)
if
self
.
use_fairscale_checkpoint
else
lambda
x
:
x
)
time_embed_dim
=
model_channels
*
4
self
.
time_embed
=
checkpoint_wrapper_fn
(
nn
.
Sequential
(
linear
(
model_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
)
)
if
self
.
num_classes
is
not
None
:
if
isinstance
(
self
.
num_classes
,
int
):
self
.
label_emb
=
nn
.
Embedding
(
num_classes
,
time_embed_dim
)
elif
self
.
num_classes
==
"continuous"
:
print
(
"setting up linear c_adm embedding layer"
)
self
.
label_emb
=
nn
.
Linear
(
1
,
time_embed_dim
)
elif
self
.
num_classes
==
"timestep"
:
self
.
label_emb
=
checkpoint_wrapper_fn
(
nn
.
Sequential
(
Timestep
(
model_channels
),
nn
.
Sequential
(
linear
(
model_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
),
)
)
elif
self
.
num_classes
==
"sequential"
:
assert
adm_in_channels
is
not
None
self
.
label_emb
=
nn
.
Sequential
(
nn
.
Sequential
(
linear
(
adm_in_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
)
)
else
:
raise
ValueError
()
self
.
input_blocks
=
nn
.
ModuleList
(
[
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
))]
)
self
.
_feature_size
=
model_channels
input_block_chans
=
[
model_channels
]
ch
=
model_channels
ds
=
1
for
level
,
mult
in
enumerate
(
channel_mult
):
for
nr
in
range
(
self
.
num_res_blocks
[
level
]):
layers
=
[
checkpoint_wrapper_fn
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
mult
*
model_channels
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
)
)
]
ch
=
mult
*
model_channels
if
ds
in
attention_resolutions
:
if
num_head_channels
==
-
1
:
dim_head
=
ch
//
num_heads
else
:
num_heads
=
ch
//
num_head_channels
dim_head
=
num_head_channels
if
legacy
:
# num_heads = 1
dim_head
=
ch
//
num_heads
if
use_spatial_transformer
else
num_head_channels
if
exists
(
disable_self_attentions
):
disabled_sa
=
disable_self_attentions
[
level
]
else
:
disabled_sa
=
False
if
not
exists
(
num_attention_blocks
)
or
nr
<
num_attention_blocks
[
level
]:
layers
.
append
(
checkpoint_wrapper_fn
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
dim_head
,
use_new_attention_order
=
use_new_attention_order
,
)
)
if
not
use_spatial_transformer
else
checkpoint_wrapper_fn
(
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
[
level
],
context_dim
=
context_dim
,
disable_self_attn
=
disabled_sa
,
use_linear
=
use_linear_in_transformer
,
attn_type
=
spatial_transformer_attn_type
,
use_checkpoint
=
use_checkpoint
,
)
)
)
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
_feature_size
+=
ch
input_block_chans
.
append
(
ch
)
if
level
!=
len
(
channel_mult
)
-
1
:
out_ch
=
ch
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
checkpoint_wrapper_fn
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
out_ch
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
down
=
True
,
)
)
if
resblock_updown
else
Downsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
)
)
ch
=
out_ch
input_block_chans
.
append
(
ch
)
ds
*=
2
self
.
_feature_size
+=
ch
if
num_head_channels
==
-
1
:
dim_head
=
ch
//
num_heads
else
:
num_heads
=
ch
//
num_head_channels
dim_head
=
num_head_channels
if
legacy
:
# num_heads = 1
dim_head
=
ch
//
num_heads
if
use_spatial_transformer
else
num_head_channels
self
.
middle_block
=
TimestepEmbedSequential
(
checkpoint_wrapper_fn
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
)
),
checkpoint_wrapper_fn
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
dim_head
,
use_new_attention_order
=
use_new_attention_order
,
)
)
if
not
use_spatial_transformer
else
checkpoint_wrapper_fn
(
SpatialTransformer
(
# always uses a self-attn
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth_middle
,
context_dim
=
context_dim
,
disable_self_attn
=
disable_middle_self_attn
,
use_linear
=
use_linear_in_transformer
,
attn_type
=
spatial_transformer_attn_type
,
use_checkpoint
=
use_checkpoint
,
)
),
checkpoint_wrapper_fn
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
)
),
)
self
.
_feature_size
+=
ch
self
.
output_blocks
=
nn
.
ModuleList
([])
for
level
,
mult
in
list
(
enumerate
(
channel_mult
))[::
-
1
]:
for
i
in
range
(
self
.
num_res_blocks
[
level
]
+
1
):
ich
=
input_block_chans
.
pop
()
layers
=
[
checkpoint_wrapper_fn
(
ResBlock
(
ch
+
ich
,
time_embed_dim
,
dropout
,
out_channels
=
model_channels
*
mult
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
)
)
]
ch
=
model_channels
*
mult
if
ds
in
attention_resolutions
:
if
num_head_channels
==
-
1
:
dim_head
=
ch
//
num_heads
else
:
num_heads
=
ch
//
num_head_channels
dim_head
=
num_head_channels
if
legacy
:
# num_heads = 1
dim_head
=
ch
//
num_heads
if
use_spatial_transformer
else
num_head_channels
if
exists
(
disable_self_attentions
):
disabled_sa
=
disable_self_attentions
[
level
]
else
:
disabled_sa
=
False
if
not
exists
(
num_attention_blocks
)
or
i
<
num_attention_blocks
[
level
]:
layers
.
append
(
checkpoint_wrapper_fn
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads_upsample
,
num_head_channels
=
dim_head
,
use_new_attention_order
=
use_new_attention_order
,
)
)
if
not
use_spatial_transformer
else
checkpoint_wrapper_fn
(
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
[
level
],
context_dim
=
context_dim
,
disable_self_attn
=
disabled_sa
,
use_linear
=
use_linear_in_transformer
,
attn_type
=
spatial_transformer_attn_type
,
use_checkpoint
=
use_checkpoint
,
)
)
)
if
level
and
i
==
self
.
num_res_blocks
[
level
]:
out_ch
=
ch
layers
.
append
(
checkpoint_wrapper_fn
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
out_ch
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
up
=
True
,
)
)
if
resblock_updown
else
Upsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
)
ds
//=
2
self
.
output_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
_feature_size
+=
ch
self
.
out
=
checkpoint_wrapper_fn
(
nn
.
Sequential
(
normalization
(
ch
),
nn
.
SiLU
(),
zero_module
(
conv_nd
(
dims
,
model_channels
,
out_channels
,
3
,
padding
=
1
)),
)
)
if
self
.
predict_codebook_ids
:
self
.
id_predictor
=
checkpoint_wrapper_fn
(
nn
.
Sequential
(
normalization
(
ch
),
conv_nd
(
dims
,
model_channels
,
n_embed
,
1
),
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
)
if
lora_init
:
self
.
_init_lora
(
lora_rank
,
lora_scale
,
lora_weight_path
)
def
_init_lora
(
self
,
rank
,
scale
,
ckpt_dir
=
None
):
inject_trainable_lora_extended
(
self
,
target_replace_module
=
None
,
rank
=
rank
,
scale
=
scale
)
if
ckpt_dir
is
not
None
:
with
open
(
os
.
path
.
join
(
ckpt_dir
,
"latest"
))
as
latest_file
:
latest
=
latest_file
.
read
().
strip
()
ckpt_path
=
os
.
path
.
join
(
ckpt_dir
,
latest
,
"mp_rank_00_model_states.pt"
)
print
(
f
"loading lora from
{
ckpt_path
}
"
)
sd
=
th
.
load
(
ckpt_path
)[
"module"
]
sd
=
{
key
[
len
(
"model.diffusion_model"
)
:]:
sd
[
key
]
for
key
in
sd
if
key
.
startswith
(
"model.diffusion_model"
)
}
self
.
load_state_dict
(
sd
,
strict
=
False
)
def
_update_scale
(
self
,
scale
):
update_lora_scale
(
self
,
scale
)
def
convert_to_fp16
(
self
):
"""
Convert the torso of the model to float16.
"""
self
.
input_blocks
.
apply
(
convert_module_to_f16
)
self
.
middle_block
.
apply
(
convert_module_to_f16
)
self
.
output_blocks
.
apply
(
convert_module_to_f16
)
def
convert_to_fp32
(
self
):
"""
Convert the torso of the model to float32.
"""
self
.
input_blocks
.
apply
(
convert_module_to_f32
)
self
.
middle_block
.
apply
(
convert_module_to_f32
)
self
.
output_blocks
.
apply
(
convert_module_to_f32
)
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        # h = x.type(self.dtype)
        h = x
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            assert False, "not supported anymore. what the f*** are you doing?"
        else:
            return self.out(h)
class NoTimeUNetModel(UNetModel):
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        timesteps = th.zeros_like(timesteps)
        return super().forward(x, timesteps, context, y, **kwargs)
class
EncoderUNetModel
(
nn
.
Module
):
"""
The half UNet model with attention and timestep embedding.
For usage, see UNet.
"""
def
__init__
(
self
,
image_size
,
in_channels
,
model_channels
,
out_channels
,
num_res_blocks
,
attention_resolutions
,
dropout
=
0
,
channel_mult
=
(
1
,
2
,
4
,
8
),
conv_resample
=
True
,
dims
=
2
,
use_checkpoint
=
False
,
use_fp16
=
False
,
num_heads
=
1
,
num_head_channels
=-
1
,
num_heads_upsample
=-
1
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
use_new_attention_order
=
False
,
pool
=
"adaptive"
,
*
args
,
**
kwargs
,
):
super
().
__init__
()
if
num_heads_upsample
==
-
1
:
num_heads_upsample
=
num_heads
self
.
in_channels
=
in_channels
self
.
model_channels
=
model_channels
self
.
out_channels
=
out_channels
self
.
num_res_blocks
=
num_res_blocks
self
.
attention_resolutions
=
attention_resolutions
self
.
dropout
=
dropout
self
.
channel_mult
=
channel_mult
self
.
conv_resample
=
conv_resample
self
.
use_checkpoint
=
use_checkpoint
self
.
dtype
=
th
.
float16
if
use_fp16
else
th
.
float32
self
.
num_heads
=
num_heads
self
.
num_head_channels
=
num_head_channels
self
.
num_heads_upsample
=
num_heads_upsample
time_embed_dim
=
model_channels
*
4
self
.
time_embed
=
nn
.
Sequential
(
linear
(
model_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
)
self
.
input_blocks
=
nn
.
ModuleList
(
[
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
))]
)
self
.
_feature_size
=
model_channels
input_block_chans
=
[
model_channels
]
ch
=
model_channels
ds
=
1
for
level
,
mult
in
enumerate
(
channel_mult
):
for
_
in
range
(
num_res_blocks
):
layers
=
[
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
mult
*
model_channels
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
)
]
ch
=
mult
*
model_channels
if
ds
in
attention_resolutions
:
layers
.
append
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
num_head_channels
,
use_new_attention_order
=
use_new_attention_order
,
)
)
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
_feature_size
+=
ch
input_block_chans
.
append
(
ch
)
if
level
!=
len
(
channel_mult
)
-
1
:
out_ch
=
ch
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
out_ch
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
down
=
True
,
)
if
resblock_updown
else
Downsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
)
)
ch
=
out_ch
input_block_chans
.
append
(
ch
)
ds
*=
2
self
.
_feature_size
+=
ch
self
.
middle_block
=
TimestepEmbedSequential
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
),
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
num_head_channels
,
use_new_attention_order
=
use_new_attention_order
,
),
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
),
)
self
.
_feature_size
+=
ch
self
.
pool
=
pool
if
pool
==
"adaptive"
:
self
.
out
=
nn
.
Sequential
(
normalization
(
ch
),
nn
.
SiLU
(),
nn
.
AdaptiveAvgPool2d
((
1
,
1
)),
zero_module
(
conv_nd
(
dims
,
ch
,
out_channels
,
1
)),
nn
.
Flatten
(),
)
elif
pool
==
"attention"
:
assert
num_head_channels
!=
-
1
self
.
out
=
nn
.
Sequential
(
normalization
(
ch
),
nn
.
SiLU
(),
AttentionPool2d
((
image_size
//
ds
),
ch
,
num_head_channels
,
out_channels
),
)
elif
pool
==
"spatial"
:
self
.
out
=
nn
.
Sequential
(
nn
.
Linear
(
self
.
_feature_size
,
2048
),
nn
.
ReLU
(),
nn
.
Linear
(
2048
,
self
.
out_channels
),
)
elif
pool
==
"spatial_v2"
:
self
.
out
=
nn
.
Sequential
(
nn
.
Linear
(
self
.
_feature_size
,
2048
),
normalization
(
2048
),
nn
.
SiLU
(),
nn
.
Linear
(
2048
,
self
.
out_channels
),
)
else
:
raise
NotImplementedError
(
f
"Unexpected
{
pool
}
pooling"
)
def
convert_to_fp16
(
self
):
"""
Convert the torso of the model to float16.
"""
self
.
input_blocks
.
apply
(
convert_module_to_f16
)
self
.
middle_block
.
apply
(
convert_module_to_f16
)
def
convert_to_fp32
(
self
):
"""
Convert the torso of the model to float32.
"""
self
.
input_blocks
.
apply
(
convert_module_to_f32
)
self
.
middle_block
.
apply
(
convert_module_to_f32
)
def
forward
(
self
,
x
,
timesteps
):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
results
=
[]
# h = x.type(self.dtype)
h
=
x
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
)
if
self
.
pool
.
startswith
(
"spatial"
):
results
.
append
(
h
.
type
(
x
.
dtype
).
mean
(
dim
=
(
2
,
3
)))
h
=
self
.
middle_block
(
h
,
emb
)
if
self
.
pool
.
startswith
(
"spatial"
):
results
.
append
(
h
.
type
(
x
.
dtype
).
mean
(
dim
=
(
2
,
3
)))
h
=
th
.
cat
(
results
,
axis
=-
1
)
return
self
.
out
(
h
)
else
:
h
=
h
.
type
(
x
.
dtype
)
return
self
.
out
(
h
)
if __name__ == "__main__":

    class Dummy(nn.Module):
        def __init__(self, in_channels=3, model_channels=64):
            super().__init__()
            self.input_blocks = nn.ModuleList(
                [TimestepEmbedSequential(conv_nd(2, in_channels, model_channels, 3, padding=1))]
            )

    model = UNetModel(
        use_checkpoint=True,
        image_size=64,
        in_channels=4,
        out_channels=4,
        model_channels=128,
        attention_resolutions=[4, 2],
        num_res_blocks=2,
        channel_mult=[1, 2, 4],
        num_head_channels=64,
        use_spatial_transformer=False,
        use_linear_in_transformer=True,
        transformer_depth=1,
        legacy=False,
    ).cuda()
    x = th.randn(11, 4, 64, 64).cuda()
    t = th.randint(low=0, high=10, size=(11,), device="cuda")
    o = model(x, t)
    print("done.")
cogvideox-based/sat/sgm/modules/diffusionmodules/sampling.py
"""
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
"""
from
typing
import
Dict
,
Union
import
torch
from
omegaconf
import
ListConfig
,
OmegaConf
from
tqdm
import
tqdm
from
...modules.diffusionmodules.sampling_utils
import
(
get_ancestral_step
,
linear_multistep_coeff
,
to_d
,
to_neg_log_sigma
,
to_sigma
,
)
from
...util
import
append_dims
,
default
,
instantiate_from_config
from
...util
import
SeededNoise
from
.guiders
import
DynamicCFG
DEFAULT_GUIDER
=
{
"target"
:
"sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
class BaseDiffusionSampler:
    def __init__(
        self,
        discretization_config: Union[Dict, ListConfig, OmegaConf],
        num_steps: Union[int, None] = None,
        guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
        verbose: bool = False,
        device: str = "cuda",
    ):
        self.num_steps = num_steps
        self.discretization = instantiate_from_config(discretization_config)
        self.guider = instantiate_from_config(
            default(
                guider_config,
                DEFAULT_GUIDER,
            )
        )
        self.verbose = verbose
        self.device = device

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device)
        uc = default(uc, cond)

        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)

        s_in = x.new_ones([x.shape[0]]).float()

        return x, s_in, sigmas, num_sigmas, cond, uc

    def denoise(self, x, denoiser, sigma, cond, uc):
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
        denoised = self.guider(denoised, sigma)
        return denoised

    def get_sigma_gen(self, num_sigmas):
        sigma_generator = range(num_sigmas - 1)
        if self.verbose:
            print("#" * 30, " Sampling setting ", "#" * 30)
            print(f"Sampler: {self.__class__.__name__}")
            print(f"Discretization: {self.discretization.__class__.__name__}")
            print(f"Guider: {self.guider.__class__.__name__}")
            sigma_generator = tqdm(
                sigma_generator,
                total=num_sigmas,
                desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
            )
        return sigma_generator
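As a point of reference, a hedged sketch of how a concrete sampler built on this base class is typically assembled from configs via instantiate_from_config. The target paths and parameter names below are assumptions based on the usual sgm module layout (discretizer.EDMDiscretization, guiders.VanillaCFG), not verified defaults of this repository, and `denoiser`, `cond`, `uc` are assumed to come from the surrounding diffusion engine.

    sampler = EulerEDMSampler(  # defined further down in this file
        discretization_config={
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {"sigma_min": 0.002, "sigma_max": 80.0, "rho": 7.0},
        },
        guider_config={
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {"scale": 6.0},
        },
        num_steps=30,
        verbose=True,
    )
    # samples = sampler(denoiser, torch.randn(1, 4, 64, 64, device="cuda"), cond, uc)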
class SingleStepDiffusionSampler(BaseDiffusionSampler):
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
        raise NotImplementedError

    def euler_step(self, x, d, dt):
        return x + dt * d


class EDMSampler(SingleStepDiffusionSampler):
    def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
        sigma_hat = sigma * (gamma + 1.0)
        if gamma > 0:
            eps = torch.randn_like(x) * self.s_noise
            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

        denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)

        euler_step = self.euler_step(x, d, dt)
        x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
        return x

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

        for i in self.get_sigma_gen(num_sigmas):
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc,
                gamma,
            )

        return x
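Worked numbers (illustration only) for the churn schedule in the loop above: each step first inflates the noise level by gamma, and the injected noise has standard deviation sqrt(sigma_hat**2 - sigma**2), so the perturbed sample again has total noise level sigma_hat before the Euler step.

    s_churn, num_sigmas, sigma = 40.0, 31, 2.0
    gamma = min(s_churn / (num_sigmas - 1), 2**0.5 - 1)  # capped at sqrt(2) - 1 ~= 0.414
    sigma_hat = sigma * (gamma + 1.0)                    # ~= 2.83
    noise_std = (sigma_hat**2 - sigma**2) ** 0.5         # ~= 2.0; sqrt(sigma**2 + noise_std**2) == sigma_hat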
class
DDIMSampler
(
SingleStepDiffusionSampler
):
def
__init__
(
self
,
s_noise
=
0.1
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
s_noise
=
s_noise
def
sampler_step
(
self
,
sigma
,
next_sigma
,
denoiser
,
x
,
cond
,
uc
=
None
,
s_noise
=
0.0
):
denoised
=
self
.
denoise
(
x
,
denoiser
,
sigma
,
cond
,
uc
)
d
=
to_d
(
x
,
sigma
,
denoised
)
dt
=
append_dims
(
next_sigma
*
(
1
-
s_noise
**
2
)
**
0.5
-
sigma
,
x
.
ndim
)
euler_step
=
x
+
dt
*
d
+
s_noise
*
append_dims
(
next_sigma
,
x
.
ndim
)
*
torch
.
randn_like
(
x
)
x
=
self
.
possible_correction_step
(
euler_step
,
x
,
d
,
dt
,
next_sigma
,
denoiser
,
cond
,
uc
)
return
x
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
):
x
,
s_in
,
sigmas
,
num_sigmas
,
cond
,
uc
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
x
=
self
.
sampler_step
(
s_in
*
sigmas
[
i
],
s_in
*
sigmas
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
,
self
.
s_noise
,
)
return
x
class
AncestralSampler
(
SingleStepDiffusionSampler
):
def
__init__
(
self
,
eta
=
1.0
,
s_noise
=
1.0
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
eta
=
eta
self
.
s_noise
=
s_noise
self
.
noise_sampler
=
lambda
x
:
torch
.
randn_like
(
x
)
def
ancestral_euler_step
(
self
,
x
,
denoised
,
sigma
,
sigma_down
):
d
=
to_d
(
x
,
sigma
,
denoised
)
dt
=
append_dims
(
sigma_down
-
sigma
,
x
.
ndim
)
return
self
.
euler_step
(
x
,
d
,
dt
)
def
ancestral_step
(
self
,
x
,
sigma
,
next_sigma
,
sigma_up
):
x
=
torch
.
where
(
append_dims
(
next_sigma
,
x
.
ndim
)
>
0.0
,
x
+
self
.
noise_sampler
(
x
)
*
self
.
s_noise
*
append_dims
(
sigma_up
,
x
.
ndim
),
x
,
)
return
x
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
):
x
,
s_in
,
sigmas
,
num_sigmas
,
cond
,
uc
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
x
=
self
.
sampler_step
(
s_in
*
sigmas
[
i
],
s_in
*
sigmas
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
,
)
return
x
class
LinearMultistepSampler
(
BaseDiffusionSampler
):
def
__init__
(
self
,
order
=
4
,
*
args
,
**
kwargs
,
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
order
=
order
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
,
**
kwargs
):
x
,
s_in
,
sigmas
,
num_sigmas
,
cond
,
uc
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
ds
=
[]
sigmas_cpu
=
sigmas
.
detach
().
cpu
().
numpy
()
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
sigma
=
s_in
*
sigmas
[
i
]
denoised
=
denoiser
(
*
self
.
guider
.
prepare_inputs
(
x
,
sigma
,
cond
,
uc
),
**
kwargs
)
denoised
=
self
.
guider
(
denoised
,
sigma
)
d
=
to_d
(
x
,
sigma
,
denoised
)
ds
.
append
(
d
)
if
len
(
ds
)
>
self
.
order
:
ds
.
pop
(
0
)
cur_order
=
min
(
i
+
1
,
self
.
order
)
coeffs
=
[
linear_multistep_coeff
(
cur_order
,
sigmas_cpu
,
i
,
j
)
for
j
in
range
(
cur_order
)]
x
=
x
+
sum
(
coeff
*
d
for
coeff
,
d
in
zip
(
coeffs
,
reversed
(
ds
)))
return
x
class EulerEDMSampler(EDMSampler):
    def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
        return euler_step


class HeunEDMSampler(EDMSampler):
    def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
        if torch.sum(next_sigma) < 1e-14:
            # Save a network evaluation if all noise levels are 0
            return euler_step
        else:
            denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
            d_new = to_d(euler_step, next_sigma, denoised)
            d_prime = (d + d_new) / 2.0

            # apply correction if noise level is not 0
            x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
            return x
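A scalar illustration (not repository code) of the trapezoidal update HeunEDMSampler applies: take the Euler step with the slope at the current noise level, re-evaluate the slope at the next level, and average the two. The toy numbers stand in for tensors.

    sigma, next_sigma = 2.0, 1.0
    x0, denoised, denoised_next = 3.0, 1.0, 1.1          # toy values
    d = (x0 - denoised) / sigma                          # to_d at sigma
    euler = x0 + (next_sigma - sigma) * d                # first-order step
    d_new = (euler - denoised_next) / next_sigma         # slope re-evaluated at next_sigma
    heun = x0 + (next_sigma - sigma) * (d + d_new) / 2.0 # second-order correction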
class
EulerAncestralSampler
(
AncestralSampler
):
def
sampler_step
(
self
,
sigma
,
next_sigma
,
denoiser
,
x
,
cond
,
uc
):
sigma_down
,
sigma_up
=
get_ancestral_step
(
sigma
,
next_sigma
,
eta
=
self
.
eta
)
denoised
=
self
.
denoise
(
x
,
denoiser
,
sigma
,
cond
,
uc
)
x
=
self
.
ancestral_euler_step
(
x
,
denoised
,
sigma
,
sigma_down
)
x
=
self
.
ancestral_step
(
x
,
sigma
,
next_sigma
,
sigma_up
)
return
x
class
DPMPP2SAncestralSampler
(
AncestralSampler
):
def
get_variables
(
self
,
sigma
,
sigma_down
):
t
,
t_next
=
[
to_neg_log_sigma
(
s
)
for
s
in
(
sigma
,
sigma_down
)]
h
=
t_next
-
t
s
=
t
+
0.5
*
h
return
h
,
s
,
t
,
t_next
def
get_mult
(
self
,
h
,
s
,
t
,
t_next
):
mult1
=
to_sigma
(
s
)
/
to_sigma
(
t
)
mult2
=
(
-
0.5
*
h
).
expm1
()
mult3
=
to_sigma
(
t_next
)
/
to_sigma
(
t
)
mult4
=
(
-
h
).
expm1
()
return
mult1
,
mult2
,
mult3
,
mult4
def
sampler_step
(
self
,
sigma
,
next_sigma
,
denoiser
,
x
,
cond
,
uc
=
None
,
**
kwargs
):
sigma_down
,
sigma_up
=
get_ancestral_step
(
sigma
,
next_sigma
,
eta
=
self
.
eta
)
denoised
=
self
.
denoise
(
x
,
denoiser
,
sigma
,
cond
,
uc
)
x_euler
=
self
.
ancestral_euler_step
(
x
,
denoised
,
sigma
,
sigma_down
)
if
torch
.
sum
(
sigma_down
)
<
1e-14
:
# Save a network evaluation if all noise levels are 0
x
=
x_euler
else
:
h
,
s
,
t
,
t_next
=
self
.
get_variables
(
sigma
,
sigma_down
)
mult
=
[
append_dims
(
mult
,
x
.
ndim
)
for
mult
in
self
.
get_mult
(
h
,
s
,
t
,
t_next
)]
x2
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised
denoised2
=
self
.
denoise
(
x2
,
denoiser
,
to_sigma
(
s
),
cond
,
uc
)
x_dpmpp2s
=
mult
[
2
]
*
x
-
mult
[
3
]
*
denoised2
# apply correction if noise level is not 0
x
=
torch
.
where
(
append_dims
(
sigma_down
,
x
.
ndim
)
>
0.0
,
x_dpmpp2s
,
x_euler
)
x
=
self
.
ancestral_step
(
x
,
sigma
,
next_sigma
,
sigma_up
)
return
x
class
DPMPP2MSampler
(
BaseDiffusionSampler
):
def
get_variables
(
self
,
sigma
,
next_sigma
,
previous_sigma
=
None
):
t
,
t_next
=
[
to_neg_log_sigma
(
s
)
for
s
in
(
sigma
,
next_sigma
)]
h
=
t_next
-
t
if
previous_sigma
is
not
None
:
h_last
=
t
-
to_neg_log_sigma
(
previous_sigma
)
r
=
h_last
/
h
return
h
,
r
,
t
,
t_next
else
:
return
h
,
None
,
t
,
t_next
def
get_mult
(
self
,
h
,
r
,
t
,
t_next
,
previous_sigma
):
mult1
=
to_sigma
(
t_next
)
/
to_sigma
(
t
)
mult2
=
(
-
h
).
expm1
()
if
previous_sigma
is
not
None
:
mult3
=
1
+
1
/
(
2
*
r
)
mult4
=
1
/
(
2
*
r
)
return
mult1
,
mult2
,
mult3
,
mult4
else
:
return
mult1
,
mult2
def
sampler_step
(
self
,
old_denoised
,
previous_sigma
,
sigma
,
next_sigma
,
denoiser
,
x
,
cond
,
uc
=
None
,
):
denoised
=
self
.
denoise
(
x
,
denoiser
,
sigma
,
cond
,
uc
)
h
,
r
,
t
,
t_next
=
self
.
get_variables
(
sigma
,
next_sigma
,
previous_sigma
)
mult
=
[
append_dims
(
mult
,
x
.
ndim
)
for
mult
in
self
.
get_mult
(
h
,
r
,
t
,
t_next
,
previous_sigma
)]
x_standard
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised
if
old_denoised
is
None
or
torch
.
sum
(
next_sigma
)
<
1e-14
:
# Save a network evaluation if all noise levels are 0 or on the first step
return
x_standard
,
denoised
else
:
denoised_d
=
mult
[
2
]
*
denoised
-
mult
[
3
]
*
old_denoised
x_advanced
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised_d
# apply correction if noise level is not 0 and not first step
x
=
torch
.
where
(
append_dims
(
next_sigma
,
x
.
ndim
)
>
0.0
,
x_advanced
,
x_standard
)
return
x
,
denoised
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
,
**
kwargs
):
x
,
s_in
,
sigmas
,
num_sigmas
,
cond
,
uc
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
old_denoised
=
None
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
x
,
old_denoised
=
self
.
sampler_step
(
old_denoised
,
None
if
i
==
0
else
s_in
*
sigmas
[
i
-
1
],
s_in
*
sigmas
[
i
],
s_in
*
sigmas
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
=
uc
,
)
return
x
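For reference, the DPM-Solver++(2M) step above written out in the half-log-sigma variable, with a small numeric check of the multipliers returned by get_mult (illustration only; the values mirror the code, they are not taken from it):

    # With t = -log(sigma), h = t_next - t and r = h_prev / h, the update is
    #   x_{i+1} = (sigma_{i+1} / sigma_i) * x_i
    #             - (exp(-h) - 1) * [ (1 + 1/(2r)) * D_i - (1/(2r)) * D_{i-1} ]
    # where D_i is the current denoised prediction and D_{i-1} the one kept from the
    # previous step; on the first step, or when sigma_{i+1} is 0, the bracket collapses to D_i.
    import math

    sigma_prev, sigma, sigma_next = 4.0, 2.0, 1.0
    t_prev, t, t_next = (-math.log(s) for s in (sigma_prev, sigma, sigma_next))
    h, h_last = t_next - t, t - t_prev
    r = h_last / h                       # = 1.0 for this geometric schedule
    mult1 = sigma_next / sigma           # 0.5
    mult2 = math.expm1(-h)               # exp(-h) - 1 = -0.5
    mult3, mult4 = 1 + 1 / (2 * r), 1 / (2 * r)  # 1.5, 0.5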
class
SDEDPMPP2MSampler
(
BaseDiffusionSampler
):
def
get_variables
(
self
,
sigma
,
next_sigma
,
previous_sigma
=
None
):
t
,
t_next
=
[
to_neg_log_sigma
(
s
)
for
s
in
(
sigma
,
next_sigma
)]
h
=
t_next
-
t
if
previous_sigma
is
not
None
:
h_last
=
t
-
to_neg_log_sigma
(
previous_sigma
)
r
=
h_last
/
h
return
h
,
r
,
t
,
t_next
else
:
return
h
,
None
,
t
,
t_next
def
get_mult
(
self
,
h
,
r
,
t
,
t_next
,
previous_sigma
):
mult1
=
to_sigma
(
t_next
)
/
to_sigma
(
t
)
*
(
-
h
).
exp
()
mult2
=
(
-
2
*
h
).
expm1
()
if
previous_sigma
is
not
None
:
mult3
=
1
+
1
/
(
2
*
r
)
mult4
=
1
/
(
2
*
r
)
return
mult1
,
mult2
,
mult3
,
mult4
else
:
return
mult1
,
mult2
def
sampler_step
(
self
,
old_denoised
,
previous_sigma
,
sigma
,
next_sigma
,
denoiser
,
x
,
cond
,
uc
=
None
,
):
denoised
=
self
.
denoise
(
x
,
denoiser
,
sigma
,
cond
,
uc
)
h
,
r
,
t
,
t_next
=
self
.
get_variables
(
sigma
,
next_sigma
,
previous_sigma
)
mult
=
[
append_dims
(
mult
,
x
.
ndim
)
for
mult
in
self
.
get_mult
(
h
,
r
,
t
,
t_next
,
previous_sigma
)]
mult_noise
=
append_dims
(
next_sigma
*
(
1
-
(
-
2
*
h
).
exp
())
**
0.5
,
x
.
ndim
)
x_standard
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised
+
mult_noise
*
torch
.
randn_like
(
x
)
if
old_denoised
is
None
or
torch
.
sum
(
next_sigma
)
<
1e-14
:
# Save a network evaluation if all noise levels are 0 or on the first step
return
x_standard
,
denoised
else
:
denoised_d
=
mult
[
2
]
*
denoised
-
mult
[
3
]
*
old_denoised
x_advanced
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised_d
+
mult_noise
*
torch
.
randn_like
(
x
)
# apply correction if noise level is not 0 and not first step
x
=
torch
.
where
(
append_dims
(
next_sigma
,
x
.
ndim
)
>
0.0
,
x_advanced
,
x_standard
)
return
x
,
denoised
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
,
scale
=
None
,
**
kwargs
):
x
,
s_in
,
sigmas
,
num_sigmas
,
cond
,
uc
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
old_denoised
=
None
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
x
,
old_denoised
=
self
.
sampler_step
(
old_denoised
,
None
if
i
==
0
else
s_in
*
sigmas
[
i
-
1
],
s_in
*
sigmas
[
i
],
s_in
*
sigmas
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
=
uc
,
)
return
x
class
SdeditEDMSampler
(
EulerEDMSampler
):
def
__init__
(
self
,
edit_ratio
=
0.5
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
edit_ratio
=
edit_ratio
def
__call__
(
self
,
denoiser
,
image
,
randn
,
cond
,
uc
=
None
,
num_steps
=
None
,
edit_ratio
=
None
):
randn_unit
=
randn
.
clone
()
randn
,
s_in
,
sigmas
,
num_sigmas
,
cond
,
uc
=
self
.
prepare_sampling_loop
(
randn
,
cond
,
uc
,
num_steps
)
if
num_steps
is
None
:
num_steps
=
self
.
num_steps
if
edit_ratio
is
None
:
edit_ratio
=
self
.
edit_ratio
x
=
None
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
if
i
/
num_steps
<
edit_ratio
:
continue
if
x
is
None
:
x
=
image
+
randn_unit
*
append_dims
(
s_in
*
sigmas
[
i
],
len
(
randn_unit
.
shape
))
gamma
=
(
min
(
self
.
s_churn
/
(
num_sigmas
-
1
),
2
**
0.5
-
1
)
if
self
.
s_tmin
<=
sigmas
[
i
]
<=
self
.
s_tmax
else
0.0
)
x
=
self
.
sampler_step
(
s_in
*
sigmas
[
i
],
s_in
*
sigmas
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
,
gamma
,
)
return
x
class
VideoDDIMSampler
(
BaseDiffusionSampler
):
def
__init__
(
self
,
fixed_frames
=
0
,
sdedit
=
False
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
fixed_frames
=
fixed_frames
self
.
sdedit
=
sdedit
def
prepare_sampling_loop
(
self
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
):
alpha_cumprod_sqrt
,
timesteps
=
self
.
discretization
(
self
.
num_steps
if
num_steps
is
None
else
num_steps
,
device
=
self
.
device
,
return_idx
=
True
,
do_append_zero
=
False
,
)
alpha_cumprod_sqrt
=
torch
.
cat
([
alpha_cumprod_sqrt
,
alpha_cumprod_sqrt
.
new_ones
([
1
])])
timesteps
=
torch
.
cat
([
torch
.
tensor
(
list
(
timesteps
)).
new_zeros
([
1
])
-
1
,
torch
.
tensor
(
list
(
timesteps
))])
uc
=
default
(
uc
,
cond
)
num_sigmas
=
len
(
alpha_cumprod_sqrt
)
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
return
x
,
s_in
,
alpha_cumprod_sqrt
,
num_sigmas
,
cond
,
uc
,
timesteps
def
denoise
(
self
,
x
,
denoiser
,
alpha_cumprod_sqrt
,
cond
,
uc
,
timestep
=
None
,
idx
=
None
,
scale
=
None
,
scale_emb
=
None
,
lq
=
None
):
additional_model_inputs
=
{}
if
isinstance
(
scale
,
torch
.
Tensor
)
==
False
and
scale
==
1
:
# print('Without CFG')
additional_model_inputs
[
"idx"
]
=
x
.
new_ones
([
x
.
shape
[
0
]])
*
timestep
if
scale_emb
is
not
None
:
additional_model_inputs
[
"scale_emb"
]
=
scale_emb
denoised
=
denoiser
(
x
,
alpha_cumprod_sqrt
,
cond
,
**
additional_model_inputs
).
to
(
torch
.
float32
)
else
:
# print('Using CFG')
additional_model_inputs
[
"idx"
]
=
torch
.
cat
([
x
.
new_ones
([
x
.
shape
[
0
]])
*
timestep
]
*
2
)
denoised
=
denoiser
(
*
self
.
guider
.
prepare_inputs
(
x
,
alpha_cumprod_sqrt
,
cond
,
uc
,
lq
),
**
additional_model_inputs
).
to
(
torch
.
float32
)
# print('denoised shape:', denoised.shape) # torch.Size([2, 8, 16, 60, 90])
if
isinstance
(
self
.
guider
,
DynamicCFG
):
denoised
=
self
.
guider
(
denoised
,
(
1
-
alpha_cumprod_sqrt
**
2
)
**
0.5
,
step_index
=
self
.
num_steps
-
timestep
,
scale
=
scale
)
else
:
denoised
=
self
.
guider
(
denoised
,
(
1
-
alpha_cumprod_sqrt
**
2
)
**
0.5
,
scale
=
scale
)
return
denoised
def
sampler_step
(
self
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
denoiser
,
x
,
cond
,
uc
=
None
,
idx
=
None
,
timestep
=
None
,
scale
=
None
,
scale_emb
=
None
,
lq
=
None
):
denoised
=
self
.
denoise
(
x
,
denoiser
,
alpha_cumprod_sqrt
,
cond
,
uc
,
timestep
,
idx
,
scale
=
scale
,
scale_emb
=
scale_emb
,
lq
=
lq
).
to
(
torch
.
float32
)
a_t
=
((
1
-
next_alpha_cumprod_sqrt
**
2
)
/
(
1
-
alpha_cumprod_sqrt
**
2
))
**
0.5
b_t
=
next_alpha_cumprod_sqrt
-
alpha_cumprod_sqrt
*
a_t
x
=
append_dims
(
a_t
,
x
.
ndim
)
*
x
+
append_dims
(
b_t
,
x
.
ndim
)
*
denoised
return
x
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
,
scale
=
None
,
scale_emb
=
None
):
x
,
s_in
,
alpha_cumprod_sqrt
,
num_sigmas
,
cond
,
uc
,
timesteps
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
x
=
self
.
sampler_step
(
s_in
*
alpha_cumprod_sqrt
[
i
],
s_in
*
alpha_cumprod_sqrt
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
,
idx
=
self
.
num_steps
-
i
,
timestep
=
timesteps
[
-
(
i
+
1
)],
scale
=
scale
,
scale_emb
=
scale_emb
,
)
return
x
class
VPSDEDPMPP2MSampler
(
VideoDDIMSampler
):
def
get_variables
(
self
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
previous_alpha_cumprod_sqrt
=
None
):
alpha_cumprod
=
alpha_cumprod_sqrt
**
2
lamb
=
((
alpha_cumprod
/
(
1
-
alpha_cumprod
))
**
0.5
).
log
()
next_alpha_cumprod
=
next_alpha_cumprod_sqrt
**
2
lamb_next
=
((
next_alpha_cumprod
/
(
1
-
next_alpha_cumprod
))
**
0.5
).
log
()
h
=
lamb_next
-
lamb
if
previous_alpha_cumprod_sqrt
is
not
None
:
previous_alpha_cumprod
=
previous_alpha_cumprod_sqrt
**
2
lamb_previous
=
((
previous_alpha_cumprod
/
(
1
-
previous_alpha_cumprod
))
**
0.5
).
log
()
h_last
=
lamb
-
lamb_previous
r
=
h_last
/
h
return
h
,
r
,
lamb
,
lamb_next
else
:
return
h
,
None
,
lamb
,
lamb_next
def
get_mult
(
self
,
h
,
r
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
previous_alpha_cumprod_sqrt
):
mult1
=
((
1
-
next_alpha_cumprod_sqrt
**
2
)
/
(
1
-
alpha_cumprod_sqrt
**
2
))
**
0.5
*
(
-
h
).
exp
()
mult2
=
(
-
2
*
h
).
expm1
()
*
next_alpha_cumprod_sqrt
if
previous_alpha_cumprod_sqrt
is
not
None
:
mult3
=
1
+
1
/
(
2
*
r
)
mult4
=
1
/
(
2
*
r
)
return
mult1
,
mult2
,
mult3
,
mult4
else
:
return
mult1
,
mult2
def
sampler_step
(
self
,
old_denoised
,
previous_alpha_cumprod_sqrt
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
denoiser
,
x
,
cond
,
uc
=
None
,
idx
=
None
,
timestep
=
None
,
scale
=
None
,
scale_emb
=
None
,
lq
=
None
):
denoised
=
self
.
denoise
(
x
,
denoiser
,
alpha_cumprod_sqrt
,
cond
,
uc
,
timestep
,
idx
,
scale
=
scale
,
scale_emb
=
scale_emb
,
lq
=
lq
).
to
(
torch
.
float32
)
if
idx
==
1
:
return
denoised
,
denoised
h
,
r
,
lamb
,
lamb_next
=
self
.
get_variables
(
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
previous_alpha_cumprod_sqrt
)
mult
=
[
append_dims
(
mult
,
x
.
ndim
)
for
mult
in
self
.
get_mult
(
h
,
r
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
previous_alpha_cumprod_sqrt
)
]
mult_noise
=
append_dims
((
1
-
next_alpha_cumprod_sqrt
**
2
)
**
0.5
*
(
1
-
(
-
2
*
h
).
exp
())
**
0.5
,
x
.
ndim
)
# print('In sampler_step denoised shape:', denoised.shape)
x_standard
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised
+
mult_noise
*
torch
.
randn_like
(
x
)
if
old_denoised
is
None
or
torch
.
sum
(
next_alpha_cumprod_sqrt
)
<
1e-14
:
# Save a network evaluation if all noise levels are 0 or on the first step
return
x_standard
,
denoised
else
:
denoised_d
=
mult
[
2
]
*
denoised
-
mult
[
3
]
*
old_denoised
x_advanced
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised_d
+
mult_noise
*
torch
.
randn_like
(
x
)
x
=
x_advanced
return
x
,
denoised
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
,
scale
=
None
,
scale_emb
=
None
,
lq
=
None
):
x
,
s_in
,
alpha_cumprod_sqrt
,
num_sigmas
,
cond
,
uc
,
timesteps
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
if
self
.
fixed_frames
>
0
:
prefix_frames
=
x
[:,
:
self
.
fixed_frames
]
old_denoised
=
None
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
if
self
.
fixed_frames
>
0
:
if
self
.
sdedit
:
rd
=
torch
.
randn_like
(
prefix_frames
)
noised_prefix_frames
=
alpha_cumprod_sqrt
[
i
]
*
prefix_frames
+
rd
*
append_dims
(
s_in
*
(
1
-
alpha_cumprod_sqrt
[
i
]
**
2
)
**
0.5
,
len
(
prefix_frames
.
shape
)
)
x
=
torch
.
cat
([
noised_prefix_frames
,
x
[:,
self
.
fixed_frames
:]],
dim
=
1
)
else
:
x
=
torch
.
cat
([
prefix_frames
,
x
[:,
self
.
fixed_frames
:]],
dim
=
1
)
# print('before sampler_step x shape:', x.shape) # torch.Size([1, 8, 16, 60, 90])
x
,
old_denoised
=
self
.
sampler_step
(
old_denoised
,
None
if
i
==
0
else
s_in
*
alpha_cumprod_sqrt
[
i
-
1
],
s_in
*
alpha_cumprod_sqrt
[
i
],
s_in
*
alpha_cumprod_sqrt
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
=
uc
,
idx
=
self
.
num_steps
-
i
,
timestep
=
timesteps
[
-
(
i
+
1
)],
scale
=
scale
,
scale_emb
=
scale_emb
,
lq
=
lq
,
)
if
self
.
fixed_frames
>
0
:
x
=
torch
.
cat
([
prefix_frames
,
x
[:,
self
.
fixed_frames
:]],
dim
=
1
)
return
x
class
VPODEDPMPP2MSampler
(
VideoDDIMSampler
):
def
get_variables
(
self
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
previous_alpha_cumprod_sqrt
=
None
):
alpha_cumprod
=
alpha_cumprod_sqrt
**
2
lamb
=
((
alpha_cumprod
/
(
1
-
alpha_cumprod
))
**
0.5
).
log
()
next_alpha_cumprod
=
next_alpha_cumprod_sqrt
**
2
lamb_next
=
((
next_alpha_cumprod
/
(
1
-
next_alpha_cumprod
))
**
0.5
).
log
()
h
=
lamb_next
-
lamb
if
previous_alpha_cumprod_sqrt
is
not
None
:
previous_alpha_cumprod
=
previous_alpha_cumprod_sqrt
**
2
lamb_previous
=
((
previous_alpha_cumprod
/
(
1
-
previous_alpha_cumprod
))
**
0.5
).
log
()
h_last
=
lamb
-
lamb_previous
r
=
h_last
/
h
return
h
,
r
,
lamb
,
lamb_next
else
:
return
h
,
None
,
lamb
,
lamb_next
def
get_mult
(
self
,
h
,
r
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
previous_alpha_cumprod_sqrt
):
mult1
=
((
1
-
next_alpha_cumprod_sqrt
**
2
)
/
(
1
-
alpha_cumprod_sqrt
**
2
))
**
0.5
mult2
=
(
-
h
).
expm1
()
*
next_alpha_cumprod_sqrt
if
previous_alpha_cumprod_sqrt
is
not
None
:
mult3
=
1
+
1
/
(
2
*
r
)
mult4
=
1
/
(
2
*
r
)
return
mult1
,
mult2
,
mult3
,
mult4
else
:
return
mult1
,
mult2
def
sampler_step
(
self
,
old_denoised
,
previous_alpha_cumprod_sqrt
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
denoiser
,
x
,
cond
,
uc
=
None
,
idx
=
None
,
timestep
=
None
,
):
denoised
=
self
.
denoise
(
x
,
denoiser
,
alpha_cumprod_sqrt
,
cond
,
uc
,
timestep
,
idx
).
to
(
torch
.
float32
)
if
idx
==
1
:
return
denoised
,
denoised
h
,
r
,
lamb
,
lamb_next
=
self
.
get_variables
(
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
previous_alpha_cumprod_sqrt
)
mult
=
[
append_dims
(
mult
,
x
.
ndim
)
for
mult
in
self
.
get_mult
(
h
,
r
,
alpha_cumprod_sqrt
,
next_alpha_cumprod_sqrt
,
previous_alpha_cumprod_sqrt
)
]
x_standard
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised
if
old_denoised
is
None
or
torch
.
sum
(
next_alpha_cumprod_sqrt
)
<
1e-14
:
# Save a network evaluation if all noise levels are 0 or on the first step
return
x_standard
,
denoised
else
:
denoised_d
=
mult
[
2
]
*
denoised
-
mult
[
3
]
*
old_denoised
x_advanced
=
mult
[
0
]
*
x
-
mult
[
1
]
*
denoised_d
x
=
x_advanced
return
x
,
denoised
def
__call__
(
self
,
denoiser
,
x
,
cond
,
uc
=
None
,
num_steps
=
None
,
scale
=
None
,
**
kwargs
):
x
,
s_in
,
alpha_cumprod_sqrt
,
num_sigmas
,
cond
,
uc
,
timesteps
=
self
.
prepare_sampling_loop
(
x
,
cond
,
uc
,
num_steps
)
old_denoised
=
None
for
i
in
self
.
get_sigma_gen
(
num_sigmas
):
x
,
old_denoised
=
self
.
sampler_step
(
old_denoised
,
None
if
i
==
0
else
s_in
*
alpha_cumprod_sqrt
[
i
-
1
],
s_in
*
alpha_cumprod_sqrt
[
i
],
s_in
*
alpha_cumprod_sqrt
[
i
+
1
],
denoiser
,
x
,
cond
,
uc
=
uc
,
idx
=
self
.
num_steps
-
i
,
timestep
=
timesteps
[
-
(
i
+
1
)],
)
return
x
cogvideox-based/sat/sgm/modules/diffusionmodules/sampling_utils.py
import
torch
from
scipy
import
integrate
from
...util
import
append_dims
from
einops
import
rearrange
class
NoDynamicThresholding
:
def
__call__
(
self
,
uncond
,
cond
,
scale
):
scale
=
append_dims
(
scale
,
cond
.
ndim
)
if
isinstance
(
scale
,
torch
.
Tensor
)
else
scale
return
uncond
+
scale
*
(
cond
-
uncond
)
class
StaticThresholding
:
def
__call__
(
self
,
uncond
,
cond
,
scale
):
result
=
uncond
+
scale
*
(
cond
-
uncond
)
result
=
torch
.
clamp
(
result
,
min
=-
1.0
,
max
=
1.0
)
return
result
def
dynamic_threshold
(
x
,
p
=
0.95
):
N
,
T
,
C
,
H
,
W
=
x
.
shape
x
=
rearrange
(
x
,
"n t c h w -> n c (t h w)"
)
l
,
r
=
x
.
quantile
(
q
=
torch
.
tensor
([
1
-
p
,
p
],
device
=
x
.
device
),
dim
=-
1
,
keepdim
=
True
)
s
=
torch
.
maximum
(
-
l
,
r
)
threshold_mask
=
(
s
>
1
).
expand
(
-
1
,
-
1
,
H
*
W
*
T
)
if
threshold_mask
.
any
():
x
=
torch
.
where
(
threshold_mask
,
x
.
clamp
(
min
=-
1
*
s
,
max
=
s
),
x
)
x
=
rearrange
(
x
,
"n c (t h w) -> n t c h w"
,
t
=
T
,
h
=
H
,
w
=
W
)
return
x
def
dynamic_thresholding2
(
x0
):
p
=
0.995
# A hyperparameter in the paper of "Imagen" [1].
origin_dtype
=
x0
.
dtype
x0
=
x0
.
to
(
torch
.
float32
)
s
=
torch
.
quantile
(
torch
.
abs
(
x0
).
reshape
((
x0
.
shape
[
0
],
-
1
)),
p
,
dim
=
1
)
s
=
append_dims
(
torch
.
maximum
(
s
,
torch
.
ones_like
(
s
).
to
(
s
.
device
)),
x0
.
dim
())
x0
=
torch
.
clamp
(
x0
,
-
s
,
s
)
# / s
return
x0
.
to
(
origin_dtype
)
def
latent_dynamic_thresholding
(
x0
):
p
=
0.9995
origin_dtype
=
x0
.
dtype
x0
=
x0
.
to
(
torch
.
float32
)
s
=
torch
.
quantile
(
torch
.
abs
(
x0
),
p
,
dim
=
2
)
s
=
append_dims
(
s
,
x0
.
dim
())
x0
=
torch
.
clamp
(
x0
,
-
s
,
s
)
/
s
return
x0
.
to
(
origin_dtype
)
def
dynamic_thresholding3
(
x0
):
p
=
0.995
# A hyperparameter in the paper of "Imagen" [1].
origin_dtype
=
x0
.
dtype
x0
=
x0
.
to
(
torch
.
float32
)
s
=
torch
.
quantile
(
torch
.
abs
(
x0
).
reshape
((
x0
.
shape
[
0
],
-
1
)),
p
,
dim
=
1
)
s
=
append_dims
(
torch
.
maximum
(
s
,
torch
.
ones_like
(
s
).
to
(
s
.
device
)),
x0
.
dim
())
x0
=
torch
.
clamp
(
x0
,
-
s
,
s
)
# / s
return
x0
.
to
(
origin_dtype
)
class
DynamicThresholding
:
def
__call__
(
self
,
uncond
,
cond
,
scale
):
mean
=
uncond
.
mean
()
std
=
uncond
.
std
()
result
=
uncond
+
scale
*
(
cond
-
uncond
)
result_mean
,
result_std
=
result
.
mean
(),
result
.
std
()
result
=
(
result
-
result_mean
)
/
result_std
*
std
# result = dynamic_thresholding3(result)
return
result
class
DynamicThresholdingV1
:
def
__init__
(
self
,
scale_factor
):
self
.
scale_factor
=
scale_factor
def
__call__
(
self
,
uncond
,
cond
,
scale
):
result
=
uncond
+
scale
*
(
cond
-
uncond
)
unscaled_result
=
result
/
self
.
scale_factor
B
,
T
,
C
,
H
,
W
=
unscaled_result
.
shape
flattened
=
rearrange
(
unscaled_result
,
"b t c h w -> b c (t h w)"
)
means
=
flattened
.
mean
(
dim
=
2
).
unsqueeze
(
2
)
recentered
=
flattened
-
means
magnitudes
=
recentered
.
abs
().
max
()
normalized
=
recentered
/
magnitudes
thresholded
=
latent_dynamic_thresholding
(
normalized
)
denormalized
=
thresholded
*
magnitudes
uncentered
=
denormalized
+
means
unflattened
=
rearrange
(
uncentered
,
"b c (t h w) -> b t c h w"
,
t
=
T
,
h
=
H
,
w
=
W
)
scaled_result
=
unflattened
*
self
.
scale_factor
return
scaled_result
class
DynamicThresholdingV2
:
def
__call__
(
self
,
uncond
,
cond
,
scale
):
B
,
T
,
C
,
H
,
W
=
uncond
.
shape
diff
=
cond
-
uncond
mim_target
=
uncond
+
diff
*
4.0
cfg_target
=
uncond
+
diff
*
8.0
mim_flattened
=
rearrange
(
mim_target
,
"b t c h w -> b c (t h w)"
)
cfg_flattened
=
rearrange
(
cfg_target
,
"b t c h w -> b c (t h w)"
)
mim_means
=
mim_flattened
.
mean
(
dim
=
2
).
unsqueeze
(
2
)
cfg_means
=
cfg_flattened
.
mean
(
dim
=
2
).
unsqueeze
(
2
)
mim_centered
=
mim_flattened
-
mim_means
cfg_centered
=
cfg_flattened
-
cfg_means
mim_scaleref
=
mim_centered
.
std
(
dim
=
2
).
unsqueeze
(
2
)
cfg_scaleref
=
cfg_centered
.
std
(
dim
=
2
).
unsqueeze
(
2
)
cfg_renormalized
=
cfg_centered
/
cfg_scaleref
*
mim_scaleref
result
=
cfg_renormalized
+
cfg_means
unflattened
=
rearrange
(
result
,
"b c (t h w) -> b t c h w"
,
t
=
T
,
h
=
H
,
w
=
W
)
return
unflattened
def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
    if order - 1 > i:
        raise ValueError(f"Order {order} too high for step {i}")

    def fn(tau):
        prod = 1.0
        for k in range(order):
            if j == k:
                continue
            prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return prod

    return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]


def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    if not eta:
        return sigma_to, 0.0
    sigma_up = torch.minimum(
        sigma_to,
        eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
    )
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up


def to_d(x, sigma, denoised):
    return (x - denoised) / append_dims(sigma, x.ndim)


def to_neg_log_sigma(sigma):
    return sigma.log().neg()


def to_sigma(neg_log_sigma):
    return neg_log_sigma.neg().exp()
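A quick check (illustration only) of the variance split get_ancestral_step performs: by construction sigma_down**2 + sigma_up**2 == sigma_to**2, so the deterministic step to sigma_down plus the injected noise of scale sigma_up lands the sample back at the requested noise level sigma_to.

    sigma_from, sigma_to = torch.tensor(2.0), torch.tensor(1.0)
    sigma_down, sigma_up = get_ancestral_step(sigma_from, sigma_to, eta=1.0)
    assert torch.allclose(sigma_down**2 + sigma_up**2, sigma_to**2)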
cogvideox-based/sat/sgm/modules/diffusionmodules/sigma_sampling.py
import torch
import torch.distributed

from sat import mpu

from ...util import default, instantiate_from_config


class EDMSampling:
    def __init__(self, p_mean=-1.2, p_std=1.2):
        self.p_mean = p_mean
        self.p_std = p_std

    def __call__(self, n_samples, rand=None):
        log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
        return log_sigma.exp()


class DiscreteSampling:
    def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False):
        self.num_idx = num_idx
        self.sigmas = instantiate_from_config(discretization_config)(
            num_idx, do_append_zero=do_append_zero, flip=flip
        )
        world_size = mpu.get_data_parallel_world_size()
        self.uniform_sampling = uniform_sampling
        if self.uniform_sampling:
            i = 1
            while True:
                if world_size % i != 0 or num_idx % (world_size // i) != 0:
                    i += 1
                else:
                    self.group_num = world_size // i
                    break

            assert self.group_num > 0
            assert world_size % self.group_num == 0
            self.group_width = world_size // self.group_num  # the number of ranks in one group
            self.sigma_interval = self.num_idx // self.group_num

    def idx_to_sigma(self, idx):
        return self.sigmas[idx]

    def __call__(self, n_samples, rand=None, return_idx=False):
        if self.uniform_sampling:
            rank = mpu.get_data_parallel_rank()
            group_index = rank // self.group_width
            idx = default(
                rand,
                torch.randint(
                    group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)
                ),
            )
        else:
            idx = default(
                rand,
                torch.randint(0, self.num_idx, (n_samples,)),
            )
        if return_idx:
            return self.idx_to_sigma(idx), idx
        else:
            return self.idx_to_sigma(idx)
class
PartialDiscreteSampling
:
def
__init__
(
self
,
discretization_config
,
total_num_idx
,
partial_num_idx
,
do_append_zero
=
False
,
flip
=
True
):
self
.
total_num_idx
=
total_num_idx
self
.
partial_num_idx
=
partial_num_idx
self
.
sigmas
=
instantiate_from_config
(
discretization_config
)(
total_num_idx
,
do_append_zero
=
do_append_zero
,
flip
=
flip
)
def
idx_to_sigma
(
self
,
idx
):
return
self
.
sigmas
[
idx
]
def
__call__
(
self
,
n_samples
,
rand
=
None
):
idx
=
default
(
rand
,
# torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)),
torch
.
randint
(
0
,
self
.
partial_num_idx
,
(
n_samples
,)),
)
return
self
.
idx_to_sigma
(
idx
)
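
A minimal usage sketch, not part of this file: drawing per-sample training sigmas with EDMSampling, where log(sigma) is normally distributed with mean p_mean and standard deviation p_std. DiscreteSampling works the same way at the call site but returns entries of a precomputed sigma table built from `discretization_config`, which is assumed to be supplied elsewhere in the training config.

# Sketch only: sample four sigmas for a batch of four training examples.
sigma_sampler = EDMSampling(p_mean=-1.2, p_std=1.2)
sigmas = sigma_sampler(n_samples=4)   # shape (4,), strictly positive
print(sigmas)
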
cogvideox-based/sat/sgm/modules/diffusionmodules/util.py
0 → 100644
View file @
1f5da520
"""
adopted from
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
and
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
and
https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
thanks!
"""
import math
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange, repeat


def make_beta_schedule(
    schedule,
    n_timestep,
    linear_start=1e-4,
    linear_end=2e-2,
):
    if schedule == "linear":
        betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
    return betas.numpy()


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def mixed_checkpoint(func, inputs: dict, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
    borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
    it also works with non-tensor inputs.
    :param func: the function to evaluate.
    :param inputs: the argument dictionary to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
        tensor_inputs = [inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)]
        non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)]
        non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)]
        args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
        return MixedCheckpointFunction.apply(
            func,
            len(tensor_inputs),
            len(non_tensor_inputs),
            tensor_keys,
            non_tensor_keys,
            *args,
        )
    else:
        return func(**inputs)


class MixedCheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        run_function,
        length_tensors,
        length_non_tensors,
        tensor_keys,
        non_tensor_keys,
        *args,
    ):
        ctx.end_tensors = length_tensors
        ctx.end_non_tensors = length_tensors + length_non_tensors
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors

        ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))}
        ctx.input_non_tensors = {
            key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]))
        }
        ctx.run_function = run_function
        ctx.input_params = list(args[ctx.end_non_tensors :])

        with torch.no_grad():
            output_tensors = ctx.run_function(**ctx.input_tensors, **ctx.input_non_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
        ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors}

        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors}
            # shallow_copies.update(additional_args)
            output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
        input_grads = torch.autograd.grad(
            output_tensors,
            list(ctx.input_tensors.values()) + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (
            (None, None, None, None, None)
            + input_grads[: ctx.end_tensors]
            + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
            + input_grads[ctx.end_tensors :]
        )


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
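
A minimal usage sketch, not part of this file: wrapping a block's forward pass with `checkpoint` so its intermediate activations are recomputed during backward instead of being stored. The block, shapes, and sizes here are arbitrary placeholders.

# Sketch only: gradient checkpointing around a small MLP.
block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

def run(inp):
    return block(inp)

y = checkpoint(run, (x,), tuple(block.parameters()), flag=True)
y.sum().backward()   # gradients reach both x and the block's parameters
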
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtype=torch.float32):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
            device=timesteps.device
        )
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    return embedding.to(dtype)
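
A minimal usage sketch, not part of this file: sinusoidal embeddings for a batch of timesteps. The dimension 320 is an arbitrary example value.

# Sketch only: three timesteps mapped to 320-dimensional embeddings.
t = torch.tensor([0, 10, 999])
emb = timestep_embedding(t, dim=320)
print(emb.shape)   # torch.Size([3, 320]); first half cosines, second half sines
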
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
class AlphaBlender(nn.Module):
    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        rearrange_pattern: str = "b t -> (b t) 1 1",
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.rearrange_pattern = rearrange_pattern

        assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}"

        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
        if self.merge_strategy == "fixed":
            alpha = self.mix_factor
        elif self.merge_strategy == "learned":
            alpha = torch.sigmoid(self.mix_factor)
        elif self.merge_strategy == "learned_with_images":
            assert image_only_indicator is not None, "need image_only_indicator ..."
            alpha = torch.where(
                image_only_indicator.bool(),
                torch.ones(1, 1, device=image_only_indicator.device),
                rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
            )
            alpha = rearrange(alpha, self.rearrange_pattern)
        else:
            raise NotImplementedError
        return alpha

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        alpha = self.get_alpha(image_only_indicator)
        x = alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
        return x
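
A minimal usage sketch, not part of this file: blending a spatial and a temporal branch with a learned scalar. The batch size, frame count, and feature shapes below are arbitrary assumptions chosen to match the default "b t -> (b t) 1 1" rearrange pattern.

# Sketch only: 2 videos of 4 frames, 16 tokens, 64 channels.
blender = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
x_spatial = torch.randn(2 * 4, 16, 64)
x_temporal = torch.randn(2 * 4, 16, 64)
image_only = torch.zeros(2, 4)            # 0 = video frame, 1 = still image
out = blender(x_spatial, x_temporal, image_only_indicator=image_only)
print(out.shape)                          # torch.Size([8, 16, 64])
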
cogvideox-based/sat/sgm/modules/diffusionmodules/wrappers.py
0 → 100644
View file @
1f5da520
import torch
import torch.nn as nn
from packaging import version

OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"


class IdentityWrapper(nn.Module):
    def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32):
        super().__init__()
        compile = (
            torch.compile
            if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model
            else lambda x: x
        )
        self.diffusion_model = compile(diffusion_model)
        self.dtype = dtype

    def forward(self, *args, **kwargs):
        return self.diffusion_model(*args, **kwargs)


class OpenAIWrapper(IdentityWrapper):
    def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor:
        for key in c:
            c[key] = c[key].to(self.dtype)

        if x.dim() == 4:
            x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
        elif x.dim() == 5:
            x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=2)
        else:
            raise ValueError("Input tensor must be 4D or 5D")

        return self.diffusion_model(
            x,
            timesteps=t,
            context=c.get("crossattn", None),
            y=c.get("vector", None),
            **kwargs,
        )
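
A minimal usage sketch, not part of this file: how OpenAIWrapper routes the conditioning dictionary. `ToyUNet` is a hypothetical stand-in with the call signature the wrapper expects; the latent layout and shapes are illustrative assumptions only.

# Sketch only: "concat" is channel-concatenated onto x, "crossattn" becomes `context`.
class ToyUNet(nn.Module):
    def forward(self, x, timesteps=None, context=None, y=None):
        return x  # a real network would predict noise / v / x0 here

wrapper = OpenAIWrapper(ToyUNet(), dtype=torch.float32)
x = torch.randn(1, 4, 2, 32, 32)                      # 5D latent, channels on dim 2
c = {
    "crossattn": torch.randn(1, 77, 1024),            # text tokens -> `context`
    "concat": torch.randn(1, 4, 1, 32, 32),           # extra channels, concatenated on dim 2
}
out = wrapper(x, torch.tensor([500]), c)
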
cogvideox-based/sat/sgm/modules/distributions/__init__.py
0 → 100644
View file @
1f5da520
cogvideox-based/sat/sgm/modules/distributions/__pycache__/__init__.cpython-39.pyc
0 → 100644
View file @
1f5da520
File added
cogvideox-based/sat/sgm/modules/distributions/__pycache__/distributions.cpython-39.pyc
0 → 100644
View file @
1f5da520
File added
cogvideox-based/sat/sgm/modules/distributions/distributions.py
0 → 100644
View file @
1f5da520
import numpy as np
import torch


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        # x = self.mean + self.std * torch.randn(self.mean.shape).to(
        #     device=self.parameters.device
        # )
        x = self.mean + self.std * torch.randn_like(self.mean).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
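
A minimal usage sketch, not part of this file: a VAE-style posterior built from concatenated mean and log-variance channels, as an encoder head with twice the latent channel count would produce. The shapes are arbitrary example values.

# Sketch only: 2*4 channels are split into mean | logvar along dim 1.
params = torch.randn(2, 8, 16, 16)
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()     # (2, 4, 16, 16) reparameterized sample
kl = posterior.kl()        # (2,) KL against a standard normal
print(z.shape, kl.shape)
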
cogvideox-based/sat/sgm/modules/ema.py
0 → 100644
View file @
1f5da520
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.m_name2s_name = {}
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            "num_updates",
            torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
        )

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace(".", "")
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def reset_num_updates(self):
        del self.num_updates
        self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
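
A minimal usage sketch, not part of this file: the store / copy_to / restore pattern around evaluation. The model and the loop below are placeholders for the real training setup.

# Sketch only: keep an EMA shadow of a small model and swap it in for evaluation.
model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

for _ in range(3):                       # after each optimizer.step() in training
    ema(model)                           # update the shadow (EMA) weights

ema.store(model.parameters())            # stash the raw training weights
ema.copy_to(model)                       # evaluate / save checkpoints with EMA weights
ema.restore(model.parameters())          # put the training weights back
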
cogvideox-based/sat/sgm/modules/encoders/__init__.py
0 → 100644
View file @
1f5da520
cogvideox-based/sat/sgm/modules/encoders/__pycache__/__init__.cpython-39.pyc
0 → 100644
View file @
1f5da520
File added
cogvideox-based/sat/sgm/modules/encoders/__pycache__/modules.cpython-39.pyc
0 → 100644
View file @
1f5da520
File added
cogvideox-based/sat/sgm/modules/encoders/modules.py
0 → 100644
View file @
1f5da520
import math
from contextlib import nullcontext
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import kornia
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import ListConfig
from torch.utils.checkpoint import checkpoint
from transformers import (
    T5EncoderModel,
    T5Tokenizer,
)

from ...util import (
    append_dims,
    autocast,
    count_params,
    default,
    disabled_train,
    expand_dims_like,
    instantiate_from_config,
)


class AbstractEmbModel(nn.Module):
    def __init__(self):
        super().__init__()
        self._is_trainable = None
        self._ucg_rate = None
        self._input_key = None

    @property
    def is_trainable(self) -> bool:
        return self._is_trainable

    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        return self._ucg_rate

    @property
    def input_key(self) -> str:
        return self._input_key

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate

    @input_key.deleter
    def input_key(self):
        del self._input_key


class GeneralConditioner(nn.Module):
    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}

    def __init__(self, emb_models: Union[List, ListConfig], cor_embs=[], cor_p=[]):
        super().__init__()
        embedders = []
        for n, embconfig in enumerate(emb_models):
            embedder = instantiate_from_config(embconfig)
            assert isinstance(
                embedder, AbstractEmbModel
            ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.is_trainable = embconfig.get("is_trainable", False)
            embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
            if not embedder.is_trainable:
                embedder.train = disabled_train
                for param in embedder.parameters():
                    param.requires_grad = False
                embedder.eval()
            print(
                f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )

            if "input_key" in embconfig:
                embedder.input_key = embconfig["input_key"]
            elif "input_keys" in embconfig:
                embedder.input_keys = embconfig["input_keys"]
            else:
                raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")

            embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
            if embedder.legacy_ucg_val is not None:
                embedder.ucg_prng = np.random.RandomState()

            embedders.append(embedder)
        self.embedders = nn.ModuleList(embedders)

        if len(cor_embs) > 0:
            assert len(cor_p) == 2 ** len(cor_embs)
        self.cor_embs = cor_embs
        self.cor_p = cor_p

    def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
        assert embedder.legacy_ucg_val is not None
        p = embedder.ucg_rate
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                batch[embedder.input_key][i] = val
        return batch

    def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, cond_or_not) -> Dict:
        assert embedder.legacy_ucg_val is not None
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            if cond_or_not[i]:
                batch[embedder.input_key][i] = val
        return batch

    def get_single_embedding(
        self,
        embedder,
        batch,
        output,
        cond_or_not: Optional[np.ndarray] = None,
        force_zero_embeddings: Optional[List] = None,
    ):
        embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
        with embedding_context():
            if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                if embedder.legacy_ucg_val is not None:
                    if cond_or_not is None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    else:
                        batch = self.surely_get_ucg_val(embedder, batch, cond_or_not)
                emb_out = embedder(batch[embedder.input_key])
            elif hasattr(embedder, "input_keys"):
                emb_out = embedder(*[batch[k] for k in embedder.input_keys])
        assert isinstance(
            emb_out, (torch.Tensor, list, tuple)
        ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
        if not isinstance(emb_out, (list, tuple)):
            emb_out = [emb_out]
        for emb in emb_out:
            out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
            if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                if cond_or_not is None:
                    emb = (
                        expand_dims_like(
                            torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)),
                            emb,
                        )
                        * emb
                    )
                else:
                    emb = (
                        expand_dims_like(
                            torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device),
                            emb,
                        )
                        * emb
                    )
            if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings:
                emb = torch.zeros_like(emb)
            if out_key in output:
                output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key])
            else:
                output[out_key] = emb
        return output

    def forward(self, batch: Dict, force_zero_embeddings: Optional[List] = None) -> Dict:
        output = dict()
        if force_zero_embeddings is None:
            force_zero_embeddings = []

        if len(self.cor_embs) > 0:
            batch_size = len(batch[list(batch.keys())[0]])
            rand_idx = np.random.choice(len(self.cor_p), size=(batch_size,), p=self.cor_p)
            for emb_idx in self.cor_embs:
                cond_or_not = rand_idx % 2
                rand_idx //= 2
                output = self.get_single_embedding(
                    self.embedders[emb_idx],
                    batch,
                    output=output,
                    cond_or_not=cond_or_not,
                    force_zero_embeddings=force_zero_embeddings,
                )

        for i, embedder in enumerate(self.embedders):
            if i in self.cor_embs:
                continue
            output = self.get_single_embedding(
                embedder, batch, output=output, force_zero_embeddings=force_zero_embeddings
            )
        return output

    def get_unconditional_conditioning(self, batch_c, batch_uc=None, force_uc_zero_embeddings=None):
        if force_uc_zero_embeddings is None:
            force_uc_zero_embeddings = []
        ucg_rates = list()
        for embedder in self.embedders:
            ucg_rates.append(embedder.ucg_rate)
            embedder.ucg_rate = 0.0
        cor_embs = self.cor_embs
        cor_p = self.cor_p
        self.cor_embs = []
        self.cor_p = []

        c = self(batch_c)
        uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)

        for embedder, rate in zip(self.embedders, ucg_rates):
            embedder.ucg_rate = rate
        self.cor_embs = cor_embs
        self.cor_p = cor_p

        return c, uc
class FrozenT5Embedder(AbstractEmbModel):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self,
        model_dir="google/t5-v1_1-xxl",
        device="cuda",
        max_length=77,
        freeze=True,
        cache_dir=None,
    ):
        super().__init__()
        if model_dir != "google/t5-v1_1-xxl":
            self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
            self.transformer = T5EncoderModel.from_pretrained(model_dir)
        else:
            self.tokenizer = T5Tokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
            self.transformer = T5EncoderModel.from_pretrained(model_dir, cache_dir=cache_dir)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()

        for param in self.parameters():
            param.requires_grad = False

    # @autocast
    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        with torch.autocast("cuda", enabled=False):
            outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
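
A minimal usage sketch, not part of this file: encoding a prompt with FrozenT5Embedder. As written this downloads the google/t5-v1_1-xxl weights on first use and assumes a CUDA device; the max_length value below is an illustrative assumption.

# Sketch only: last hidden states are later used as crossattn context by the conditioner.
embedder = FrozenT5Embedder(max_length=226).to("cuda")
z = embedder.encode(["a timelapse of clouds over a mountain lake"])
print(z.shape)   # (1, 226, hidden_size)
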
cogvideox-based/sat/sgm/modules/fuse_sft_block.py
0 → 100644
View file @
1f5da520
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F


class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))


class InflatedConv3d(nn.Conv2d):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=video_length)

        return x


class ResnetBlock3D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=16,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm}")

            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb=None):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class Fuse_sft_block(nn.Module):
    def __init__(self, enc_ch, dec_ch):
        super().__init__()
        self.shared = nn.Sequential(
            ResnetBlock3D(in_channels=enc_ch + dec_ch, out_channels=dec_ch, temb_channels=None),
            ResnetBlock3D(in_channels=dec_ch, out_channels=dec_ch, temb_channels=None),
        )

        self.scale = nn.Conv3d(dec_ch, dec_ch, 3, 1, 1)  # InflatedConv3d(dec_ch, dec_ch, 3, 1, 1)
        self.shift = nn.Conv3d(dec_ch, dec_ch, 3, 1, 1)  # InflatedConv3d(dec_ch, dec_ch, 3, 1, 1)

    def forward(self, enc_feat, dec_feat, w=1):
        enc_feat = self.shared(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(enc_feat)
        shift = self.shift(enc_feat)
        residual = w * (dec_feat * scale + shift)
        out = dec_feat + residual
        return out


if __name__ == "__main__":
    block = Fuse_sft_block(16, 16)
    enc_feat = torch.randn(1, 16, 4, 60, 90)
    dec_feat = torch.randn(1, 16, 4, 60, 90)
    out = block(enc_feat, dec_feat)
    print(out.shape)
\ No newline at end of file
cogvideox-based/sat/sgm/modules/video_attention.py
0 → 100644
View file @
1f5da520
import torch

from ..modules.attention import *
from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding


class TimeMixSequential(nn.Sequential):
    def forward(self, x, context=None, timesteps=None):
        for layer in self:
            x = layer(x, context, timesteps)

        return x


class VideoTransformerBlock(nn.Module):
    ATTENTION_MODES = {
        "softmax": CrossAttention,
        "softmax-xformers": MemoryEfficientCrossAttention,
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        timesteps=None,
        ff_in=False,
        inner_dim=None,
        attn_mode="softmax",
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        switch_temporal_ca_to_sa=False,
    ):
        super().__init__()

        attn_cls = self.ATTENTION_MODES[attn_mode]

        self.ff_in = ff_in or inner_dim is not None
        if inner_dim is None:
            inner_dim = dim

        assert int(n_heads * d_head) == inner_dim

        self.is_res = inner_dim == dim

        if self.ff_in:
            self.norm_in = nn.LayerNorm(dim)
            self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff)

        self.timesteps = timesteps
        self.disable_self_attn = disable_self_attn
        if self.disable_self_attn:
            self.attn1 = attn_cls(
                query_dim=inner_dim,
                heads=n_heads,
                dim_head=d_head,
                context_dim=context_dim,
                dropout=dropout,
            )  # is a cross-attention
        else:
            self.attn1 = attn_cls(
                query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
            )  # is a self-attention

        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)

        if disable_temporal_crossattention:
            if switch_temporal_ca_to_sa:
                raise ValueError
            else:
                self.attn2 = None
        else:
            self.norm2 = nn.LayerNorm(inner_dim)
            if switch_temporal_ca_to_sa:
                self.attn2 = attn_cls(
                    query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
                )  # is a self-attention
            else:
                self.attn2 = attn_cls(
                    query_dim=inner_dim,
                    context_dim=context_dim,
                    heads=n_heads,
                    dim_head=d_head,
                    dropout=dropout,
                )  # is self-attn if context is none

        self.norm1 = nn.LayerNorm(inner_dim)
        self.norm3 = nn.LayerNorm(inner_dim)
        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

        self.checkpoint = checkpoint
        if self.checkpoint:
            print(f"{self.__class__.__name__} is using checkpointing")

    def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None) -> torch.Tensor:
        if self.checkpoint:
            return checkpoint(self._forward, x, context, timesteps)
        else:
            return self._forward(x, context, timesteps=timesteps)

    def _forward(self, x, context=None, timesteps=None):
        assert self.timesteps or timesteps
        assert not (self.timesteps and timesteps) or self.timesteps == timesteps
        timesteps = self.timesteps or timesteps
        B, S, C = x.shape
        x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)

        if self.ff_in:
            x_skip = x
            x = self.ff_in(self.norm_in(x))
            if self.is_res:
                x += x_skip

        if self.disable_self_attn:
            x = self.attn1(self.norm1(x), context=context) + x
        else:
            x = self.attn1(self.norm1(x)) + x

        if self.attn2 is not None:
            if self.switch_temporal_ca_to_sa:
                x = self.attn2(self.norm2(x)) + x
            else:
                x = self.attn2(self.norm2(x), context=context) + x
        x_skip = x
        x = self.ff(self.norm3(x))
        if self.is_res:
            x += x_skip

        x = rearrange(x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps)
        return x

    def get_last_layer(self):
        return self.ff.net[-1].weight


str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}


class SpatialVideoTransformer(SpatialTransformer):
    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        use_linear=False,
        context_dim=None,
        use_spatial_context=False,
        timesteps=None,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        time_context_dim=None,
        ff_in=False,
        checkpoint=False,
        time_depth=1,
        attn_mode="softmax",
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        max_time_embed_period: int = 10000,
        dtype="fp32",
    ):
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            attn_type=attn_mode,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
        )
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period

        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        inner_dim = n_heads * d_head
        if use_spatial_context:
            time_context_dim = context_dim

        self.time_stack = nn.ModuleList(
            [
                VideoTransformerBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=time_context_dim,
                    timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    attn_mode=attn_mode,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                )
                for _ in range(self.depth)
            ]
        )

        assert len(self.time_stack) == len(self.transformer_blocks)

        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        time_embed_dim = self.in_channels * 4
        self.time_pos_embed = nn.Sequential(
            linear(self.in_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, self.in_channels),
        )

        self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy)
        self.dtype = str_to_dtype[dtype]

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        time_context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        _, _, h, w = x.shape
        x_in = x
        spatial_context = None
        if exists(context):
            spatial_context = context

        if self.use_spatial_context:
            assert context.ndim == 3, f"n dims of spatial context should be 3 but are {context.ndim}"

            time_context = context
            time_context_first_timestep = time_context[::timesteps]
            time_context = repeat(time_context_first_timestep, "b ... -> (b n) ...", n=h * w)
        elif time_context is not None and not self.use_spatial_context:
            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
            if time_context.ndim == 2:
                time_context = rearrange(time_context, "b c -> b 1 c")

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        if self.use_linear:
            x = self.proj_in(x)

        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(
            num_frames,
            self.in_channels,
            repeat_only=False,
            max_period=self.max_time_embed_period,
            dtype=self.dtype,
        )
        emb = self.time_pos_embed(t_emb)
        emb = emb[:, None, :]

        for it_, (block, mix_block) in enumerate(zip(self.transformer_blocks, self.time_stack)):
            x = block(
                x,
                context=spatial_context,
            )

            x_mix = x
            x_mix = x_mix + emb

            x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
            x = self.time_mixer(
                x_spatial=x,
                x_temporal=x_mix,
                image_only_indicator=image_only_indicator,
            )
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        if not self.use_linear:
            x = self.proj_out(x)
        out = x + x_in
        return out