ModelZoo / FlashVideo_pytorch · Commits

Commit 3b804999, authored Feb 20, 2025 by chenzk

v1.0

Pipeline #2420 failed with stages in 0 seconds · Changes: 146 · Pipelines: 1

Showing 20 changed files with 2426 additions and 0 deletions (+2426, -0)
Changed files:

    flashvideo/sgm/models/__pycache__/__init__.cpython-310.pyc                      +0    -0
    flashvideo/sgm/models/__pycache__/autoencoder.cpython-310.pyc                   +0    -0
    flashvideo/sgm/models/autoencoder.py                                            +613  -0
    flashvideo/sgm/modules/__init__.py                                              +8    -0
    flashvideo/sgm/modules/__pycache__/__init__.cpython-310.pyc                     +0    -0
    flashvideo/sgm/modules/__pycache__/attention.cpython-310.pyc                    +0    -0
    flashvideo/sgm/modules/__pycache__/cp_enc_dec.cpython-310.pyc                   +0    -0
    flashvideo/sgm/modules/__pycache__/ema.cpython-310.pyc                          +0    -0
    flashvideo/sgm/modules/__pycache__/video_attention.cpython-310.pyc              +0    -0
    flashvideo/sgm/modules/attention.py                                             +649  -0
    flashvideo/sgm/modules/autoencoding/__init__.py                                 +0    -0
    flashvideo/sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc        +0    -0
    flashvideo/sgm/modules/autoencoding/__pycache__/temporal_ae.cpython-310.pyc     +0    -0
    flashvideo/sgm/modules/autoencoding/losses/__init__.py                          +8    -0
    flashvideo/sgm/modules/autoencoding/losses/discriminator_loss.py                +323  -0
    flashvideo/sgm/modules/autoencoding/losses/lpips.py                             +75   -0
    flashvideo/sgm/modules/autoencoding/losses/video_loss.py                        +750  -0
    flashvideo/sgm/modules/autoencoding/lpips/__init__.py                           +0    -0
    flashvideo/sgm/modules/autoencoding/lpips/__pycache__/__init__.cpython-310.pyc  +0    -0
    flashvideo/sgm/modules/autoencoding/lpips/__pycache__/util.cpython-310.pyc      +0    -0
flashvideo/sgm/models/__pycache__/__init__.cpython-310.pyc  0 → 100644 · File added (binary)
flashvideo/sgm/models/__pycache__/autoencoder.cpython-310.pyc  0 → 100644 · File added (binary)
flashvideo/sgm/models/autoencoder.py  0 → 100644  (+613 lines)
import logging
import math
import random
import re
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed
import torch.nn as nn
from einops import rearrange
from packaging import version

from ..modules.autoencoding.regularizers import AbstractRegularizer
from ..modules.cp_enc_dec import _conv_gather, _conv_split
from ..modules.ema import LitEma
from ..util import (default, get_context_parallel_group,
                    get_context_parallel_group_rank, get_nested_attribute,
                    get_obj_from_str, initialize_context_parallel,
                    instantiate_from_config, is_context_parallel_initialized)

logpy = logging.getLogger(__name__)


class AbstractAutoencoder(pl.LightningModule):
    """
    This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
    unCLIP models, etc. Hence, it is fairly general, and specific features
    (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = 'jpg',
    ):
        super().__init__()

        self.input_key = input_key
        self.use_ema = ema_decay is not None
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            self.model_ema = LitEma(self, decay=ema_decay)
            logpy.info(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')

        if version.parse(torch.__version__) >= version.parse('2.0.0'):
            self.automatic_optimization = False

    def apply_ckpt(self, ckpt: Union[None, str, dict]):
        if ckpt is None:
            return
        if isinstance(ckpt, str):
            ckpt = {
                'target': 'sgm.modules.checkpoint.CheckpointEngine',
                'params': {'ckpt_path': ckpt},
            }
        engine = instantiate_from_config(ckpt)
        engine(self)

    @abstractmethod
    def get_input(self, batch) -> Any:
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        # for EMA computation
        if self.use_ema:
            self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                logpy.info(f'{context}: Switched to EMA weights')
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    logpy.info(f'{context}: Restored training weights')

    @abstractmethod
    def encode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError('encode()-method of abstract base class called')

    @abstractmethod
    def decode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError('decode()-method of abstract base class called')

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
        return get_obj_from_str(cfg['target'])(params, lr=lr, **cfg.get('params', dict()))

    def configure_optimizers(self) -> Any:
        raise NotImplementedError()


class AutoencodingEngine(AbstractAutoencoder):
    """
    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
    (we also restore them explicitly as special cases for legacy reasons).
    Regularizations such as KL or VQ are moved to the regularizer class.
    """

    def __init__(
        self,
        *args,
        encoder_config: Dict,
        decoder_config: Dict,
        loss_config: Dict,
        regularizer_config: Dict,
        optimizer_config: Union[Dict, None] = None,
        lr_g_factor: float = 1.0,
        trainable_ae_params: Optional[List[List[str]]] = None,
        ae_optimizer_args: Optional[List[dict]] = None,
        trainable_disc_params: Optional[List[List[str]]] = None,
        disc_optimizer_args: Optional[List[dict]] = None,
        disc_start_iter: int = 0,
        diff_boost_factor: float = 3.0,
        ckpt_engine: Union[None, str, dict] = None,
        ckpt_path: Optional[str] = None,
        additional_decode_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.automatic_optimization = False  # pytorch lightning

        self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
        self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
        self.loss: torch.nn.Module = instantiate_from_config(loss_config)
        self.regularization: AbstractRegularizer = instantiate_from_config(regularizer_config)
        self.optimizer_config = default(optimizer_config, {'target': 'torch.optim.Adam'})
        self.diff_boost_factor = diff_boost_factor
        self.disc_start_iter = disc_start_iter
        self.lr_g_factor = lr_g_factor
        self.trainable_ae_params = trainable_ae_params
        if self.trainable_ae_params is not None:
            self.ae_optimizer_args = default(
                ae_optimizer_args,
                [{} for _ in range(len(self.trainable_ae_params))],
            )
            assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
        else:
            self.ae_optimizer_args = [{}]  # makes type consistent

        self.trainable_disc_params = trainable_disc_params
        if self.trainable_disc_params is not None:
            self.disc_optimizer_args = default(
                disc_optimizer_args,
                [{} for _ in range(len(self.trainable_disc_params))],
            )
            assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
        else:
            self.disc_optimizer_args = [{}]  # makes type consistent

        if ckpt_path is not None:
            assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
            logpy.warn('Checkpoint path is deprecated, use `checkpoint_engine` instead')
        self.apply_ckpt(default(ckpt_path, ckpt_engine))
        self.additional_decode_keys = set(default(additional_decode_keys, []))

    def get_input(self, batch: Dict) -> torch.Tensor:
        # assuming unified data format, dataloader returns a dict.
        # image tensors should be scaled to -1 ... 1 and in channels-first
        # format (e.g., bchw instead of bhwc)
        return batch[self.input_key]

    def get_autoencoder_params(self) -> list:
        params = []
        if hasattr(self.loss, 'get_trainable_autoencoder_parameters'):
            params += list(self.loss.get_trainable_autoencoder_parameters())
        if hasattr(self.regularization, 'get_trainable_parameters'):
            params += list(self.regularization.get_trainable_parameters())
        params = params + list(self.encoder.parameters())
        params = params + list(self.decoder.parameters())
        return params

    def get_discriminator_params(self) -> list:
        if hasattr(self.loss, 'get_trainable_parameters'):
            params = list(self.loss.get_trainable_parameters())  # e.g., discriminator
        else:
            params = []
        return params

    def get_last_layer(self):
        return self.decoder.get_last_layer()

    def encode(
        self,
        x: torch.Tensor,
        return_reg_log: bool = False,
        unregularized: bool = False,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        z = self.encoder(x, **kwargs)
        if unregularized:
            return z, dict()
        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.decoder(z, **kwargs)
        return x

    def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z, **additional_decode_kwargs)
        return z, dec, reg_log

    def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
        x = self.get_input(batch)
        additional_decode_kwargs = {
            key: batch[key]
            for key in self.additional_decode_keys.intersection(batch)
        }
        z, xrec, regularization_log = self(x, **additional_decode_kwargs)
        if hasattr(self.loss, 'forward_keys'):
            extra_info = {
                'z': z,
                'optimizer_idx': optimizer_idx,
                'global_step': self.global_step,
                'last_layer': self.get_last_layer(),
                'split': 'train',
                'regularization_log': regularization_log,
                'autoencoder': self,
            }
            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
        else:
            extra_info = dict()

        if optimizer_idx == 0:
            # autoencoder
            out_loss = self.loss(x, xrec, **extra_info)
            if isinstance(out_loss, tuple):
                aeloss, log_dict_ae = out_loss
            else:
                # simple loss function
                aeloss = out_loss
                log_dict_ae = {'train/loss/rec': aeloss.detach()}

            self.log_dict(
                log_dict_ae,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=True,
                sync_dist=False,
            )
            self.log(
                'loss',
                aeloss.mean().detach(),
                prog_bar=True,
                logger=False,
                on_epoch=False,
                on_step=True,
            )
            return aeloss
        elif optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
            # -> discriminator always needs to return a tuple
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss
        else:
            raise NotImplementedError(f'Unknown optimizer {optimizer_idx}')

    def training_step(self, batch: dict, batch_idx: int):
        opts = self.optimizers()
        if not isinstance(opts, list):
            # Non-adversarial case
            opts = [opts]
        optimizer_idx = batch_idx % len(opts)
        if self.global_step < self.disc_start_iter:
            optimizer_idx = 0
        opt = opts[optimizer_idx]
        opt.zero_grad()
        with opt.toggle_model():
            loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx)
            self.manual_backward(loss)
        opt.step()

    def validation_step(self, batch: dict, batch_idx: int) -> Dict:
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix='_ema')
            log_dict.update(log_dict_ema)
        return log_dict

    def _validation_step(self, batch: dict, batch_idx: int, postfix: str = '') -> Dict:
        x = self.get_input(batch)

        z, xrec, regularization_log = self(x)
        if hasattr(self.loss, 'forward_keys'):
            extra_info = {
                'z': z,
                'optimizer_idx': 0,
                'global_step': self.global_step,
                'last_layer': self.get_last_layer(),
                'split': 'val' + postfix,
                'regularization_log': regularization_log,
                'autoencoder': self,
            }
            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
        else:
            extra_info = dict()
        out_loss = self.loss(x, xrec, **extra_info)
        if isinstance(out_loss, tuple):
            aeloss, log_dict_ae = out_loss
        else:
            # simple loss function
            aeloss = out_loss
            log_dict_ae = {f'val{postfix}/loss/rec': aeloss.detach()}
        full_log_dict = log_dict_ae

        if 'optimizer_idx' in extra_info:
            extra_info['optimizer_idx'] = 1
            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
            full_log_dict.update(log_dict_disc)
        self.log(
            f'val{postfix}/loss/rec',
            log_dict_ae[f'val{postfix}/loss/rec'],
            sync_dist=True,
        )
        self.log_dict(full_log_dict, sync_dist=True)
        return full_log_dict

    def get_param_groups(self, parameter_names: List[List[str]],
                         optimizer_args: List[dict]) -> Tuple[List[Dict[str, Any]], int]:
        groups = []
        num_params = 0
        for names, args in zip(parameter_names, optimizer_args):
            params = []
            for pattern_ in names:
                pattern_params = []
                pattern = re.compile(pattern_)
                for p_name, param in self.named_parameters():
                    if re.match(pattern, p_name):
                        pattern_params.append(param)
                        num_params += param.numel()
                if len(pattern_params) == 0:
                    logpy.warn(f'Did not find parameters for pattern {pattern_}')
                params.extend(pattern_params)
            groups.append({'params': params, **args})
        return groups, num_params

    def configure_optimizers(self) -> List[torch.optim.Optimizer]:
        if self.trainable_ae_params is None:
            ae_params = self.get_autoencoder_params()
        else:
            ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
            logpy.info(f'Number of trainable autoencoder parameters: {num_ae_params:,}')
        if self.trainable_disc_params is None:
            disc_params = self.get_discriminator_params()
        else:
            disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
            logpy.info(f'Number of trainable discriminator parameters: {num_disc_params:,}')
        opt_ae = self.instantiate_optimizer_from_config(
            ae_params,
            default(self.lr_g_factor, 1.0) * self.learning_rate,
            self.optimizer_config,
        )
        opts = [opt_ae]
        if len(disc_params) > 0:
            opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
            opts.append(opt_disc)

        return opts

    @torch.no_grad()
    def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
        log = dict()
        additional_decode_kwargs = {}
        x = self.get_input(batch)
        additional_decode_kwargs.update(
            {key: batch[key] for key in self.additional_decode_keys.intersection(batch)})

        _, xrec, _ = self(x, **additional_decode_kwargs)
        log['inputs'] = x
        log['reconstructions'] = xrec
        diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
        diff.clamp_(0, 1.0)
        log['diff'] = 2.0 * diff - 1.0
        # diff_boost shows location of small errors, by boosting their
        # brightness.
        log['diff_boost'] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
        if hasattr(self.loss, 'log_images'):
            log.update(self.loss.log_images(x, xrec))
        with self.ema_scope():
            _, xrec_ema, _ = self(x, **additional_decode_kwargs)
            log['reconstructions_ema'] = xrec_ema
            diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
            diff_ema.clamp_(0, 1.0)
            log['diff_ema'] = 2.0 * diff_ema - 1.0
            log['diff_boost_ema'] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
        if additional_log_kwargs:
            additional_decode_kwargs.update(additional_log_kwargs)
            _, xrec_add, _ = self(x, **additional_decode_kwargs)
            log_str = 'reconstructions-' + '-'.join(
                [f'{key}={additional_log_kwargs[key]}' for key in additional_log_kwargs])
            log[log_str] = xrec_add
        return log


class AutoencodingEngineLegacy(AutoencodingEngine):
    def __init__(self, embed_dim: int, **kwargs):
        self.max_batch_size = kwargs.pop('max_batch_size', None)
        ddconfig = kwargs.pop('ddconfig')
        ckpt_path = kwargs.pop('ckpt_path', None)
        ckpt_engine = kwargs.pop('ckpt_engine', None)
        super().__init__(
            encoder_config={
                'target': 'sgm.modules.diffusionmodules.model.Encoder',
                'params': ddconfig,
            },
            decoder_config={
                'target': 'sgm.modules.diffusionmodules.model.Decoder',
                'params': ddconfig,
            },
            **kwargs,
        )
        self.quant_conv = torch.nn.Conv2d(
            (1 + ddconfig['double_z']) * ddconfig['z_channels'],
            (1 + ddconfig['double_z']) * embed_dim,
            1,
        )
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
        self.embed_dim = embed_dim

        self.apply_ckpt(default(ckpt_path, ckpt_engine))

    def get_autoencoder_params(self) -> list:
        params = super().get_autoencoder_params()
        return params

    def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        if self.max_batch_size is None:
            z = self.encoder(x)
            z = self.quant_conv(z)
        else:
            N = x.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            z = list()
            for i_batch in range(n_batches):
                z_batch = self.encoder(x[i_batch * bs:(i_batch + 1) * bs])
                z_batch = self.quant_conv(z_batch)
                z.append(z_batch)
            z = torch.cat(z, 0)

        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
        if self.max_batch_size is None:
            dec = self.post_quant_conv(z)
            dec = self.decoder(dec, **decoder_kwargs)
        else:
            N = z.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            dec = list()
            for i_batch in range(n_batches):
                dec_batch = self.post_quant_conv(z[i_batch * bs:(i_batch + 1) * bs])
                dec_batch = self.decoder(dec_batch, **decoder_kwargs)
                dec.append(dec_batch)
            dec = torch.cat(dec, 0)

        return dec


class IdentityFirstStage(AbstractAutoencoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_input(self, x: Any) -> Any:
        return x

    def encode(self, x: Any, *args, **kwargs) -> Any:
        return x

    def decode(self, x: Any, *args, **kwargs) -> Any:
        return x


import os


class VideoAutoencodingEngine(AutoencodingEngine):
    def __init__(
        self,
        ckpt_path: Union[None, str] = None,
        ignore_keys: Union[Tuple, list] = (),
        image_video_weights=[1, 1],
        only_train_decoder=False,
        context_parallel_size=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.context_parallel_size = context_parallel_size
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
        return self.log_images(batch, additional_log_kwargs, **kwargs)

    def get_input(self, batch: dict) -> torch.Tensor:
        if self.context_parallel_size > 0:
            if not is_context_parallel_initialized():
                initialize_context_parallel(self.context_parallel_size)

            batch = batch[self.input_key]

            global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
            torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())

            batch = _conv_split(batch, dim=2, kernel_size=1)
            return batch

        return batch[self.input_key]

    def apply_ckpt(self, ckpt: Union[None, str, dict]):
        if ckpt is None:
            return
        self.init_from_ckpt(ckpt)

    def init_from_ckpt(self, path, ignore_keys=list()):
        if os.environ.get('SKIP_LOAD', False):
            print(f'skip loading from {path}')
            return
        sd = torch.load(path, map_location='cpu')['state_dict']
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    del sd[k]
        missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
        print('Missing keys: ', missing_keys)
        print('Unexpected keys: ', unexpected_keys)
        print(f'Restored from {path}')
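
Usage sketch (not part of the commit): how the encode/decode API of AutoencodingEngine above is typically driven. Construction is elided because it needs concrete encoder/decoder/loss/regularizer configs; `engine` below stands for an already-instantiated AutoencodingEngine, and the input shape is an illustrative assumption.

    import torch

    # engine: an AutoencodingEngine built with concrete encoder/decoder/loss/
    # regularizer config targets (assumed, not shown here).
    x = torch.randn(2, 3, 256, 256)  # images scaled to [-1, 1], channels-first
    z, reg_log = engine.encode(x, return_reg_log=True)  # latent + regularizer stats (e.g. KL terms)
    x_rec = engine.decode(z)         # reconstruction with the same shape as x
    z2, dec, reg_log2 = engine(x)    # forward() bundles both steps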
flashvideo/sgm/modules/__init__.py  0 → 100644  (+8 lines)
from .encoders.modules import GeneralConditioner

UNCONDITIONAL_CONFIG = {
    'target': 'sgm.modules.GeneralConditioner',
    'params': {'emb_models': []},
}
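
For orientation (not part of the commit): configs like UNCONDITIONAL_CONFIG are consumed by `instantiate_from_config`, which resolves the dotted `target` path to a class and calls it with `params`. A minimal sketch of that pattern, under the assumption that the sgm helpers follow the usual target/params convention (the real `sgm.util` implementation may differ in details):

    import importlib

    def get_obj_from_str(string):
        # 'sgm.modules.GeneralConditioner' -> the GeneralConditioner class
        module, cls = string.rsplit('.', 1)
        return getattr(importlib.import_module(module), cls)

    def instantiate_from_config(config):
        # build the target with its keyword params,
        # e.g. GeneralConditioner(emb_models=[])
        return get_obj_from_str(config['target'])(**config.get('params', dict()))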
flashvideo/sgm/modules/__pycache__/__init__.cpython-310.pyc  0 → 100644 · File added (binary)
flashvideo/sgm/modules/__pycache__/attention.cpython-310.pyc  0 → 100644 · File added (binary)
flashvideo/sgm/modules/__pycache__/cp_enc_dec.cpython-310.pyc  0 → 100644 · File added (binary)
flashvideo/sgm/modules/__pycache__/ema.cpython-310.pyc  0 → 100644 · File added (binary)
flashvideo/sgm/modules/__pycache__/video_attention.cpython-310.pyc  0 → 100644 · File added (binary)
flashvideo/sgm/modules/attention.py  0 → 100644  (+649 lines)
import math
from inspect import isfunction
from typing import Any, Optional

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from packaging import version
from torch import nn

if version.parse(torch.__version__) >= version.parse('2.0.0'):
    SDP_IS_AVAILABLE = True
    from torch.backends.cuda import SDPBackend, sdp_kernel

    BACKEND_MAP = {
        SDPBackend.MATH: {
            'enable_math': True,
            'enable_flash': False,
            'enable_mem_efficient': False,
        },
        SDPBackend.FLASH_ATTENTION: {
            'enable_math': False,
            'enable_flash': True,
            'enable_mem_efficient': False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            'enable_math': False,
            'enable_flash': False,
            'enable_mem_efficient': True,
        },
        None: {'enable_math': True, 'enable_flash': True, 'enable_mem_efficient': True},
    }
else:
    from contextlib import nullcontext

    SDP_IS_AVAILABLE = False
    sdp_kernel = nullcontext
    BACKEND_MAP = {}
    print(f'No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, '
          f'you are using PyTorch {torch.__version__}. You might want to consider upgrading.')

try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILABLE = True
except:
    XFORMERS_IS_AVAILABLE = False
    print("no module 'xformers'. Processing without...")

from .diffusionmodules.util import checkpoint


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)


class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        w_ = torch.einsum('bij,bjk->bik', q, k)

        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        h_ = self.proj_out(h_)

        return x + h_


class CrossAttention(nn.Module):
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        backend=None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        self.backend = backend

    def forward(
        self,
        x,
        context=None,
        mask=None,
        additional_tokens=None,
        n_times_crossframe_attn_in_self=0,
    ):
        h = self.heads

        if additional_tokens is not None:
            # get the number of masked tokens at the beginning of the output sequence
            n_tokens_to_mask = additional_tokens.shape[1]
            # add additional token
            x = torch.cat([additional_tokens, x], dim=1)

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        if n_times_crossframe_attn_in_self:
            # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
            assert x.shape[0] % n_times_crossframe_attn_in_self == 0
            n_cp = x.shape[0] // n_times_crossframe_attn_in_self
            k = repeat(k[::n_times_crossframe_attn_in_self], 'b ... -> (b n) ...', n=n_cp)
            v = repeat(v[::n_times_crossframe_attn_in_self], 'b ... -> (b n) ...', n=n_cp)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        ## old
        """
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', sim, v)
        """
        ## new
        with sdp_kernel(**BACKEND_MAP[self.backend]):
            # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # scale is dim_head ** -0.5 per default

        del q, k, v
        out = rearrange(out, 'b h n d -> b n (h d)', h=h)

        if additional_tokens is not None:
            # remove additional token
            out = out[:, n_tokens_to_mask:]
        return self.to_out(out)


class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
        super().__init__()
        print(f'Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using '
              f'{heads} heads with a dimension of {dim_head}.')
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None

    def forward(
        self,
        x,
        context=None,
        mask=None,
        additional_tokens=None,
        n_times_crossframe_attn_in_self=0,
    ):
        if additional_tokens is not None:
            # get the number of masked tokens at the beginning of the output sequence
            n_tokens_to_mask = additional_tokens.shape[1]
            # add additional token
            x = torch.cat([additional_tokens, x], dim=1)
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        if n_times_crossframe_attn_in_self:
            # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
            assert x.shape[0] % n_times_crossframe_attn_in_self == 0
            # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
            k = repeat(
                k[::n_times_crossframe_attn_in_self],
                'b ... -> (b n) ...',
                n=n_times_crossframe_attn_in_self,
            )
            v = repeat(
                v[::n_times_crossframe_attn_in_self],
                'b ... -> (b n) ...',
                n=n_times_crossframe_attn_in_self,
            )

        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3).reshape(b, t.shape[1], self.heads, self.dim_head).permute(
                0, 2, 1, 3).reshape(b * self.heads, t.shape[1], self.dim_head).contiguous(),
            (q, k, v),
        )

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        # TODO: Use this directly in the attention operation, as a bias
        if exists(mask):
            raise NotImplementedError
        out = (out.unsqueeze(0).reshape(b, self.heads, out.shape[1], self.dim_head).permute(
            0, 2, 1, 3).reshape(b, out.shape[1], self.heads * self.dim_head))
        if additional_tokens is not None:
            # remove additional token
            out = out[:, n_tokens_to_mask:]
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    ATTENTION_MODES = {
        'softmax': CrossAttention,  # vanilla attention
        'softmax-xformers': MemoryEfficientCrossAttention,  # ampere
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
        attn_mode='softmax',
        sdp_backend=None,
    ):
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        if attn_mode != 'softmax' and not XFORMERS_IS_AVAILABLE:
            print(f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                  f'This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}')
            attn_mode = 'softmax'
        elif attn_mode == 'softmax' and not SDP_IS_AVAILABLE:
            print('We do not support vanilla attention anymore, as it is too expensive. Sorry.')
            if not XFORMERS_IS_AVAILABLE:
                assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
            else:
                print('Falling back to xformers efficient attention.')
                attn_mode = 'softmax-xformers'
        attn_cls = self.ATTENTION_MODES[attn_mode]
        if version.parse(torch.__version__) >= version.parse('2.0.0'):
            assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
        else:
            assert sdp_backend is None
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
            backend=sdp_backend,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            backend=sdp_backend,
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint
        if self.checkpoint:
            print(f'{self.__class__.__name__} is using checkpointing')

    def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
        kwargs = {'x': x}

        if context is not None:
            kwargs.update({'context': context})

        if additional_tokens is not None:
            kwargs.update({'additional_tokens': additional_tokens})

        if n_times_crossframe_attn_in_self:
            kwargs.update({'n_times_crossframe_attn_in_self': n_times_crossframe_attn_in_self})

        # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
        x = (self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None,
            additional_tokens=additional_tokens,
            n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
        ) + x)
        x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
        x = self.ff(self.norm3(x)) + x
        return x


class BasicTransformerSingleLayerBlock(nn.Module):
    ATTENTION_MODES = {
        'softmax': CrossAttention,  # vanilla attention
        'softmax-xformers': MemoryEfficientCrossAttention,
        # on the A100s not quite as fast as the above version
        # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        attn_mode='softmax',
    ):
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x), context=context) + x
        x = self.ff(self.norm2(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,
        attn_type='softmax',
        use_checkpoint=True,
        # sdp_backend=SDPBackend.FLASH_ATTENTION
        sdp_backend=None,
    ):
        super().__init__()
        print(f'constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads')
        from omegaconf import ListConfig

        if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
            context_dim = [context_dim]
        if exists(context_dim) and isinstance(context_dim, list):
            if depth != len(context_dim):
                print(f'WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, '
                      f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now.")
                # depth does not match context dims.
                assert all(map(lambda x: x == context_dim[0], context_dim)), \
                    'need homogenous context_dim to match depth automatically'
                context_dim = depth * [context_dim[0]]
        elif context_dim is None:
            context_dim = [None] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                disable_self_attn=disable_self_attn,
                attn_mode=attn_type,
                checkpoint=use_checkpoint,
                sdp_backend=sdp_backend,
            ) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
        else:
            # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            if i > 0 and len(context) == 1:
                i = 0  # use same context for each block
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
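
Usage sketch (not part of the commit), with illustrative shapes; assumes PyTorch >= 2.0 so the SDP path above is active:

    import torch

    # 64-channel feature map, 4 heads of dim 16; use_checkpoint=False so the
    # sketch also runs outside a training step.
    st = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1,
                            use_linear=True, use_checkpoint=False)
    x = torch.randn(2, 64, 32, 32)  # (b, c, h, w)
    ctx = torch.randn(2, 77, 64)    # (b, tokens, dim); context_dim defaults to the query dim
    out = st(x, context=ctx)        # residual output, same shape as x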
flashvideo/sgm/modules/autoencoding/__init__.py  0 → 100644  (empty file)
flashvideo/sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc  0 → 100644 · File added (binary)
flashvideo/sgm/modules/autoencoding/__pycache__/temporal_ae.cpython-310.pyc  0 → 100644 · File added (binary)
flashvideo/sgm/modules/autoencoding/losses/__init__.py  0 → 100644  (+8 lines)
__all__ = [
    'GeneralLPIPSWithDiscriminator',
    'LatentLPIPS',
]

from .discriminator_loss import GeneralLPIPSWithDiscriminator
from .lpips import LatentLPIPS
from .video_loss import VideoAutoencoderLoss
flashvideo/sgm/modules/autoencoding/losses/discriminator_loss.py  0 → 100644  (+323 lines)
from typing import Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchvision
from einops import rearrange
from matplotlib import colormaps
from matplotlib import pyplot as plt

from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss


class GeneralLPIPSWithDiscriminator(nn.Module):
    def __init__(
        self,
        disc_start: int,
        logvar_init: float = 0.0,
        disc_num_layers: int = 3,
        disc_in_channels: int = 3,
        disc_factor: float = 1.0,
        disc_weight: float = 1.0,
        perceptual_weight: float = 1.0,
        disc_loss: str = 'hinge',
        scale_input_to_tgt_size: bool = False,
        dims: int = 2,
        learn_logvar: bool = False,
        regularization_weights: Union[None, Dict[str, float]] = None,
        additional_log_keys: Optional[List[str]] = None,
        discriminator_config: Optional[Dict] = None,
    ):
        super().__init__()
        self.dims = dims
        if self.dims > 2:
            print(f'running with dims={dims}. This means that for perceptual loss '
                  f'calculation, the LPIPS loss will be applied to each frame '
                  f'independently.')
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        assert disc_loss in ['hinge', 'vanilla']
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.full((), logvar_init), requires_grad=learn_logvar)
        self.learn_logvar = learn_logvar

        discriminator_config = default(
            discriminator_config,
            {
                'target': 'sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator',
                'params': {
                    'input_nc': disc_in_channels,
                    'n_layers': disc_num_layers,
                    'use_actnorm': False,
                },
            },
        )

        self.discriminator = instantiate_from_config(discriminator_config).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == 'hinge' else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.regularization_weights = default(regularization_weights, {})

        self.forward_keys = [
            'optimizer_idx',
            'global_step',
            'last_layer',
            'split',
            'regularization_log',
        ]

        self.additional_log_keys = set(default(additional_log_keys, []))
        self.additional_log_keys.update(set(self.regularization_weights.keys()))

    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        return self.discriminator.parameters()

    def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
        if self.learn_logvar:
            yield self.logvar
        yield from ()

    @torch.no_grad()
    def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]:
        # calc logits of real/fake
        logits_real = self.discriminator(inputs.contiguous().detach())
        if len(logits_real.shape) < 4:
            # Non patch-discriminator
            return dict()
        logits_fake = self.discriminator(reconstructions.contiguous().detach())
        # -> (b, 1, h, w)

        # parameters for colormapping
        high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
        cmap = colormaps['PiYG']  # diverging colormap

        def to_colormap(logits: torch.Tensor) -> torch.Tensor:
            """(b, 1, ...) -> (b, 3, ...)"""
            logits = (logits + high) / (2 * high)
            logits_np = cmap(logits.cpu().numpy())[..., :3]  # truncate alpha channel
            # -> (b, 1, ..., 3)
            logits = torch.from_numpy(logits_np).to(logits.device)
            return rearrange(logits, 'b 1 ... c -> b c ...')

        logits_real = torch.nn.functional.interpolate(
            logits_real,
            size=inputs.shape[-2:],
            mode='nearest',
            antialias=False,
        )
        logits_fake = torch.nn.functional.interpolate(
            logits_fake,
            size=reconstructions.shape[-2:],
            mode='nearest',
            antialias=False,
        )

        # alpha value of logits for overlay
        alpha_real = torch.abs(logits_real) / high
        alpha_fake = torch.abs(logits_fake) / high
        # -> (b, 1, h, w) in range [0, 0.5]
        # alpha value of lines don't really matter, since the values are the same
        # for both images and logits anyway
        grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
        grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
        grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
        # -> (1, h, w)
        # blend logits and images together

        # prepare logits for plotting
        logits_real = to_colormap(logits_real)
        logits_fake = to_colormap(logits_fake)
        # resize logits
        # -> (b, 3, h, w)

        # make some grids
        # add all logits to one plot
        logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
        logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
        # I just love how torchvision calls the number of columns `nrow`
        grid_logits = torch.cat((logits_real, logits_fake), dim=1)
        # -> (3, h, w)

        grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
        grid_images_fake = torchvision.utils.make_grid(0.5 * reconstructions + 0.5, nrow=4)
        grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
        # -> (3, h, w) in range [0, 1]

        grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images

        # Create labeled colorbar
        dpi = 100
        height = 128 / dpi
        width = grid_logits.shape[2] / dpi
        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
        img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
        plt.colorbar(
            img,
            cax=ax,
            orientation='horizontal',
            fraction=0.9,
            aspect=width / height,
            pad=0.0,
        )
        img.set_visible(False)
        fig.tight_layout()
        fig.canvas.draw()
        # manually convert figure to numpy
        cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
        cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
        cbar = rearrange(cbar, 'h w c -> c h w').to(grid_logits.device)

        # Add colorbar to plot
        annotated_grid = torch.cat((grid_logits, cbar), dim=1)
        blended_grid = torch.cat((grid_blend, cbar), dim=1)
        return {
            'vis_logits': 2 * annotated_grid[None, ...] - 1,
            'vis_logits_blended': 2 * blended_grid[None, ...] - 1,
        }

    def calculate_adaptive_weight(self, nll_loss: torch.Tensor, g_loss: torch.Tensor,
                                  last_layer: torch.Tensor) -> torch.Tensor:
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        *,  # added because I changed the order here
        regularization_log: Dict[str, torch.Tensor],
        optimizer_idx: int,
        global_step: int,
        last_layer: torch.Tensor,
        split: str = 'train',
        weights: Union[None, float, torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        if self.scale_input_to_tgt_size:
            inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode='bicubic', antialias=True)

        if self.dims > 2:
            inputs, reconstructions = map(
                lambda x: rearrange(x, 'b c t h w -> (b t) c h w'),
                (inputs, reconstructions),
            )

        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            frame_indices = torch.randn((inputs.shape[0], inputs.shape[2])).topk(1, dim=-1).indices

            from sgm.modules.autoencoding.losses.video_loss import pick_video_frame

            input_frames = pick_video_frame(inputs, frame_indices)
            recon_frames = pick_video_frame(reconstructions, frame_indices)

            p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean()
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if global_step >= self.discriminator_iter_start or not self.training:
                logits_fake = self.discriminator(reconstructions.contiguous())
                g_loss = -torch.mean(logits_fake)
                if self.training:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                else:
                    d_weight = torch.tensor(1.0)
            else:
                d_weight = torch.tensor(0.0)
                g_loss = torch.tensor(0.0, requires_grad=True)

            loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
            log = dict()
            for k in regularization_log:
                if k in self.regularization_weights:
                    loss = loss + self.regularization_weights[k] * regularization_log[k]
                if k in self.additional_log_keys:
                    log[f'{split}/{k}'] = regularization_log[k].detach().float().mean()

            log.update({
                f'{split}/loss/total': loss.clone().detach().mean(),
                f'{split}/loss/nll': nll_loss.detach().mean(),
                f'{split}/loss/rec': rec_loss.detach().mean(),
                f'{split}/loss/percep': p_loss.detach().mean(),
                f'{split}/loss/g': g_loss.detach().mean(),
                f'{split}/scalars/logvar': self.logvar.detach(),
                f'{split}/scalars/d_weight': d_weight.detach(),
            })

            return loss, log
        elif optimizer_idx == 1:
            # second pass for discriminator update
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            if global_step >= self.discriminator_iter_start or not self.training:
                d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
            else:
                d_loss = torch.tensor(0.0, requires_grad=True)

            log = {
                f'{split}/loss/disc': d_loss.clone().detach().mean(),
                f'{split}/logits/real': logits_real.detach().mean(),
                f'{split}/logits/fake': logits_fake.detach().mean(),
            }
            return d_loss, log
        else:
            raise NotImplementedError(f'Unknown optimizer_idx {optimizer_idx}')

    def get_nll_loss(
        self,
        rec_loss: torch.Tensor,
        weights: Optional[Union[float, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        return nll_loss, weighted_nll_loss
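
Worked sketch (not part of the commit) of `calculate_adaptive_weight` above: the generator loss is rescaled so that its gradient norm at the decoder's last layer matches the reconstruction gradient norm, i.e. d_weight = ||grad(nll)|| / (||grad(g)|| + 1e-4), clamped to [0, 1e4] and multiplied by disc_weight. A toy reproduction with made-up tensors:

    import torch

    last_layer = torch.randn(8, requires_grad=True)  # stands in for the decoder's last weight
    nll_loss = (last_layer ** 2).sum()               # toy reconstruction term
    g_loss = last_layer.sum()                        # toy generator (GAN) term
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = (nll_grads.norm() / (g_grads.norm() + 1e-4)).clamp(0.0, 1e4).detach()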
flashvideo/sgm/modules/autoencoding/losses/lpips.py  0 → 100644  (+75 lines)
import torch
import torch.nn as nn

from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS


class LatentLPIPS(nn.Module):
    def __init__(
        self,
        decoder_config,
        perceptual_weight=1.0,
        latent_weight=1.0,
        scale_input_to_tgt_size=False,
        scale_tgt_to_input_size=False,
        perceptual_weight_on_inputs=0.0,
    ):
        super().__init__()
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        self.scale_tgt_to_input_size = scale_tgt_to_input_size
        self.init_decoder(decoder_config)
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        self.latent_weight = latent_weight
        self.perceptual_weight_on_inputs = perceptual_weight_on_inputs

    def init_decoder(self, config):
        self.decoder = instantiate_from_config(config)
        if hasattr(self.decoder, 'encoder'):
            del self.decoder.encoder

    def forward(self, latent_inputs, latent_predictions, image_inputs, split='train'):
        log = dict()
        loss = (latent_inputs - latent_predictions)**2
        log[f'{split}/latent_l2_loss'] = loss.mean().detach()
        image_reconstructions = None
        if self.perceptual_weight > 0.0:
            image_reconstructions = self.decoder.decode(latent_predictions)
            image_targets = self.decoder.decode(latent_inputs)
            perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous())
            loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
            log[f'{split}/perceptual_loss'] = perceptual_loss.mean().detach()

        if self.perceptual_weight_on_inputs > 0.0:
            image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions))
            if self.scale_input_to_tgt_size:
                image_inputs = torch.nn.functional.interpolate(
                    image_inputs,
                    image_reconstructions.shape[2:],
                    mode='bicubic',
                    antialias=True,
                )
            elif self.scale_tgt_to_input_size:
                image_reconstructions = torch.nn.functional.interpolate(
                    image_reconstructions,
                    image_inputs.shape[2:],
                    mode='bicubic',
                    antialias=True,
                )

            perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous())
            loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
            log[f'{split}/perceptual_loss_on_inputs'] = perceptual_loss2.mean().detach()
        return loss, log
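
In compact form (not in the commit), the objective assembled by LatentLPIPS.forward above is, with D the frozen decoder from init_decoder and the three weights being latent_weight, perceptual_weight and perceptual_weight_on_inputs:

    L = w_{\mathrm{latent}} \cdot \overline{(z_{\mathrm{in}} - z_{\mathrm{pred}})^2}
      + w_{\mathrm{percep}} \cdot \overline{\mathrm{LPIPS}\big(D(z_{\mathrm{in}}),\, D(z_{\mathrm{pred}})\big)}
      + w_{\mathrm{inputs}} \cdot \overline{\mathrm{LPIPS}\big(x_{\mathrm{in}},\, D(z_{\mathrm{pred}})\big)}

where each perceptual term is included only when its weight is positive, and the overline denotes the mean.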
flashvideo/sgm/modules/autoencoding/losses/video_loss.py  0 → 100644  (+750 lines)
from
math
import
log2
from
typing
import
Any
,
Union
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torchvision
from
beartype
import
beartype
from
einops
import
einsum
,
rearrange
,
repeat
from
einops.layers.torch
import
Rearrange
from
kornia.filters
import
filter3d
from
sgm.modules.autoencoding.vqvae.movq_enc_3d
import
(
CausalConv3d
,
DownSample3D
)
from
sgm.util
import
instantiate_from_config
from
torch
import
Tensor
from
torch.autograd
import
grad
as
torch_grad
from
torch.cuda.amp
import
autocast
from
torchvision.models
import
VGG16_Weights
from
..magvit2_pytorch
import
FeedForward
,
LinearSpaceAttention
,
Residual
from
.lpips
import
LPIPS
def
exists
(
v
):
return
v
is
not
None
def
pair
(
t
):
return
t
if
isinstance
(
t
,
tuple
)
else
(
t
,
t
)
def
leaky_relu
(
p
=
0.1
):
return
nn
.
LeakyReLU
(
p
)
def
hinge_discr_loss
(
fake
,
real
):
return
(
F
.
relu
(
1
+
fake
)
+
F
.
relu
(
1
-
real
)).
mean
()
def
hinge_gen_loss
(
fake
):
return
-
fake
.
mean
()
@
autocast
(
enabled
=
False
)
@
beartype
def
grad_layer_wrt_loss
(
loss
:
Tensor
,
layer
:
nn
.
Parameter
):
return
torch_grad
(
outputs
=
loss
,
inputs
=
layer
,
grad_outputs
=
torch
.
ones_like
(
loss
),
retain_graph
=
True
)[
0
].
detach
()
def
pick_video_frame
(
video
,
frame_indices
):
batch
,
device
=
video
.
shape
[
0
],
video
.
device
video
=
rearrange
(
video
,
'b c f ... -> b f c ...'
)
batch_indices
=
torch
.
arange
(
batch
,
device
=
device
)
batch_indices
=
rearrange
(
batch_indices
,
'b -> b 1'
)
images
=
video
[
batch_indices
,
frame_indices
]
images
=
rearrange
(
images
,
'b 1 c ... -> b c ...'
)
return
images
def
gradient_penalty
(
images
,
output
):
batch_size
=
images
.
shape
[
0
]
gradients
=
torch_grad
(
outputs
=
output
,
inputs
=
images
,
grad_outputs
=
torch
.
ones
(
output
.
size
(),
device
=
images
.
device
),
create_graph
=
True
,
retain_graph
=
True
,
only_inputs
=
True
,
)[
0
]
gradients
=
rearrange
(
gradients
,
'b ... -> b (...)'
)
return
((
gradients
.
norm
(
2
,
dim
=
1
)
-
1
)
**
2
).
mean
()
# discriminator with anti-aliased downsampling (blurpool Zhang et al.)
class Blur(nn.Module):

    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x, space_only=False, time_only=False):
        assert not (space_only and time_only)

        f = self.f

        if space_only:
            f = einsum('i, j -> i j', f, f)
            f = rearrange(f, '... -> 1 1 ...')
        elif time_only:
            f = rearrange(f, 'f -> 1 f 1 1')
        else:
            f = einsum('i, j, k -> i j k', f, f, f)
            f = rearrange(f, '... -> 1 ...')

        is_images = x.ndim == 4

        if is_images:
            x = rearrange(x, 'b c h w -> b c 1 h w')

        out = filter3d(x, f, normalized=True)

        if is_images:
            out = rearrange(out, 'b c 1 h w -> b c h w')

        return out
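# The buffered [1, 2, 1] binomial tap is expanded into a separable 2D or 3D
# kernel on the fly; shapes are preserved, only high frequencies are damped.
# A minimal shape check (illustrative, not part of the original file):
#
#     blur = Blur()
#     clip = torch.randn(2, 3, 8, 32, 32)                # (b, c, f, h, w)
#     assert blur(clip, space_only=True).shape == clip.shape
#     image = torch.randn(2, 3, 32, 32)                  # 4-dim inputs work too
#     assert blur(image, space_only=True).shape == image.shape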
class DiscriminatorBlock(nn.Module):

    def __init__(self,
                 input_channels,
                 filters,
                 downsample=True,
                 antialiased_downsample=True):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels,
                                  filters,
                                  1,
                                  stride=(2 if downsample else 1))

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu(),
        )

        self.maybe_blur = Blur() if antialiased_downsample else None

        self.downsample = (nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),
            nn.Conv2d(filters * 4, filters, 1)) if downsample else None)

    def forward(self, x):
        res = self.conv_res(x)

        x = self.net(x)

        if exists(self.downsample):
            if exists(self.maybe_blur):
                x = self.maybe_blur(x, space_only=True)

            x = self.downsample(x)

        x = (x + res) * (2**-0.5)
        return x
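# A minimal shape check for the residual block above (illustrative):
#
#     block = DiscriminatorBlock(3, 64, downsample=True)
#     x = torch.randn(2, 3, 64, 64)
#     assert block(x).shape == (2, 64, 32, 32)   # space-to-depth halves H and W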
class Discriminator(nn.Module):

    @beartype
    def __init__(
        self,
        *,
        dim,
        image_size,
        channels=3,
        max_dim=512,
        attn_heads=8,
        attn_dim_head=32,
        linear_attn_dim_head=8,
        linear_attn_heads=16,
        ff_mult=4,
        antialiased_downsample=False,
    ):
        super().__init__()
        image_size = pair(image_size)
        min_image_resolution = min(image_size)

        num_layers = int(log2(min_image_resolution) - 2)

        layer_dims = [channels] + [(dim * 4) * (2**i)
                                   for i in range(num_layers + 1)]
        layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
        layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

        blocks = []
        attn_blocks = []

        image_resolution = min_image_resolution

        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(layer_dims_in_out) - 1)

            block = DiscriminatorBlock(
                in_chan,
                out_chan,
                downsample=is_not_last,
                antialiased_downsample=antialiased_downsample)

            attn_block = nn.Sequential(
                Residual(
                    LinearSpaceAttention(dim=out_chan,
                                         heads=linear_attn_heads,
                                         dim_head=linear_attn_dim_head)),
                Residual(FeedForward(dim=out_chan, mult=ff_mult,
                                     images=True)),
            )

            blocks.append(nn.ModuleList([block, attn_block]))

            image_resolution //= 2

        self.blocks = nn.ModuleList(blocks)

        dim_last = layer_dims[-1]

        downsample_factor = 2**num_layers
        last_fmap_size = tuple(
            map(lambda n: n // downsample_factor, image_size))

        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last

        self.to_logits = nn.Sequential(
            nn.Conv2d(dim_last, dim_last, 3, padding=1),
            leaky_relu(),
            Rearrange('b ... -> b (...)'),
            nn.Linear(latent_dim, 1),
            Rearrange('b 1 -> b'),
        )

    def forward(self, x):
        for block, attn_block in self.blocks:
            x = block(x)
            x = attn_block(x)

        return self.to_logits(x)
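# A minimal usage sketch (illustrative; image_size should be a power of two so
# the int(log2(...)) layer counting and space-to-depth splits work out):
#
#     disc = Discriminator(dim=32, image_size=64)   # 5 residual+attention stages
#     images = torch.randn(2, 3, 64, 64)
#     logits = disc(images)                          # shape (2,): one logit per image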
class DiscriminatorBlock3D(nn.Module):

    def __init__(
        self,
        input_channels,
        filters,
        antialiased_downsample=True,
    ):
        super().__init__()
        self.conv_res = nn.Conv3d(input_channels, filters, 1, stride=2)

        self.net = nn.Sequential(
            nn.Conv3d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv3d(filters, filters, 3, padding=1),
            leaky_relu(),
        )

        self.maybe_blur = Blur() if antialiased_downsample else None

        self.downsample = nn.Sequential(
            Rearrange('b c (f p1) (h p2) (w p3) -> b (c p1 p2 p3) f h w',
                      p1=2,
                      p2=2,
                      p3=2),
            nn.Conv3d(filters * 8, filters, 1),
        )

    def forward(self, x):
        res = self.conv_res(x)

        x = self.net(x)

        if exists(self.downsample):
            if exists(self.maybe_blur):
                x = self.maybe_blur(x, space_only=True)

            x = self.downsample(x)

        x = (x + res) * (2**-0.5)
        return x
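# Unlike the 2D block, this one space-to-depths all three axes at once
# (2*2*2 = 8x channel expansion before the 1x1 Conv3d), so frames, height and
# width are halved together. A minimal shape check (illustrative):
#
#     block3d = DiscriminatorBlock3D(3, 64)
#     clip = torch.randn(1, 3, 8, 64, 64)               # (b, c, f, h, w)
#     assert block3d(clip).shape == (1, 64, 4, 32, 32)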
class DiscriminatorBlock3DWithfirstframe(nn.Module):

    def __init__(
        self,
        input_channels,
        filters,
        antialiased_downsample=True,
        pad_mode='first',
    ):
        super().__init__()
        self.downsample_res = DownSample3D(
            in_channels=input_channels,
            out_channels=filters,
            with_conv=True,
            compress_time=True,
        )

        self.net = nn.Sequential(
            CausalConv3d(input_channels,
                         filters,
                         kernel_size=3,
                         pad_mode=pad_mode),
            leaky_relu(),
            CausalConv3d(filters, filters, kernel_size=3, pad_mode=pad_mode),
            leaky_relu(),
        )

        self.maybe_blur = Blur() if antialiased_downsample else None

        self.downsample = DownSample3D(
            in_channels=filters,
            out_channels=filters,
            with_conv=True,
            compress_time=True,
        )

    def forward(self, x):
        res = self.downsample_res(x)

        x = self.net(x)

        if exists(self.downsample):
            if exists(self.maybe_blur):
                x = self.maybe_blur(x, space_only=True)

            x = self.downsample(x)

        x = (x + res) * (2**-0.5)
        return x
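# Same residual layout as DiscriminatorBlock3D, but built from the CausalConv3d
# and DownSample3D blocks imported from movq_enc_3d: with pad_mode='first' the
# temporal padding is presumably drawn from the clip's own first frame rather
# than zeros, which keeps the first frame from mixing with padding and is what
# pairs this block with Discriminator3DWithfirstframe below.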
class Discriminator3D(nn.Module):

    @beartype
    def __init__(
        self,
        *,
        dim,
        image_size,
        frame_num,
        channels=3,
        max_dim=512,
        linear_attn_dim_head=8,
        linear_attn_heads=16,
        ff_mult=4,
        antialiased_downsample=False,
    ):
        super().__init__()
        image_size = pair(image_size)
        min_image_resolution = min(image_size)

        num_layers = int(log2(min_image_resolution) - 2)
        temporal_num_layers = int(log2(frame_num))
        self.temporal_num_layers = temporal_num_layers

        layer_dims = [channels] + [(dim * 4) * (2**i)
                                   for i in range(num_layers + 1)]
        layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
        layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

        blocks = []

        image_resolution = min_image_resolution
        frame_resolution = frame_num

        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(layer_dims_in_out) - 1)

            if ind < temporal_num_layers:
                block = DiscriminatorBlock3D(
                    in_chan,
                    out_chan,
                    antialiased_downsample=antialiased_downsample,
                )

                blocks.append(block)

                frame_resolution //= 2
            else:
                block = DiscriminatorBlock(
                    in_chan,
                    out_chan,
                    downsample=is_not_last,
                    antialiased_downsample=antialiased_downsample,
                )
                attn_block = nn.Sequential(
                    Residual(
                        LinearSpaceAttention(dim=out_chan,
                                             heads=linear_attn_heads,
                                             dim_head=linear_attn_dim_head)),
                    Residual(
                        FeedForward(dim=out_chan, mult=ff_mult, images=True)),
                )

                blocks.append(nn.ModuleList([block, attn_block]))

            image_resolution //= 2

        self.blocks = nn.ModuleList(blocks)

        dim_last = layer_dims[-1]

        downsample_factor = 2**num_layers
        last_fmap_size = tuple(
            map(lambda n: n // downsample_factor, image_size))

        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last

        self.to_logits = nn.Sequential(
            nn.Conv2d(dim_last, dim_last, 3, padding=1),
            leaky_relu(),
            Rearrange('b ... -> b (...)'),
            nn.Linear(latent_dim, 1),
            Rearrange('b 1 -> b'),
        )

    def forward(self, x):
        for i, layer in enumerate(self.blocks):
            if i < self.temporal_num_layers:
                x = layer(x)
                if i == self.temporal_num_layers - 1:
                    x = rearrange(x, 'b c f h w -> (b f) c h w')
            else:
                block, attn_block = layer
                x = block(x)
                x = attn_block(x)

        return self.to_logits(x)
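# A minimal usage sketch (illustrative; frame_num is assumed to be a power of
# two, since int(log2(frame_num)) 3D stages are stacked and each halves the
# frame count before the remaining frames are folded into the batch):
#
#     disc3d = Discriminator3D(dim=32, image_size=64, frame_num=8)
#     video = torch.randn(1, 3, 8, 64, 64)       # (b, c, f, h, w)
#     logits = disc3d(video)                      # 8 frames collapse to 1 -> shape (1,)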
class Discriminator3DWithfirstframe(nn.Module):

    @beartype
    def __init__(
        self,
        *,
        dim,
        image_size,
        frame_num,
        channels=3,
        max_dim=512,
        linear_attn_dim_head=8,
        linear_attn_heads=16,
        ff_mult=4,
        antialiased_downsample=False,
    ):
        super().__init__()
        image_size = pair(image_size)
        min_image_resolution = min(image_size)

        num_layers = int(log2(min_image_resolution) - 2)
        temporal_num_layers = int(log2(frame_num))
        self.temporal_num_layers = temporal_num_layers

        layer_dims = [channels] + [(dim * 4) * (2**i)
                                   for i in range(num_layers + 1)]
        layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
        layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

        blocks = []

        image_resolution = min_image_resolution
        frame_resolution = frame_num

        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(layer_dims_in_out) - 1)

            if ind < temporal_num_layers:
                block = DiscriminatorBlock3DWithfirstframe(
                    in_chan,
                    out_chan,
                    antialiased_downsample=antialiased_downsample,
                )

                blocks.append(block)

                frame_resolution //= 2
            else:
                block = DiscriminatorBlock(
                    in_chan,
                    out_chan,
                    downsample=is_not_last,
                    antialiased_downsample=antialiased_downsample,
                )
                attn_block = nn.Sequential(
                    Residual(
                        LinearSpaceAttention(dim=out_chan,
                                             heads=linear_attn_heads,
                                             dim_head=linear_attn_dim_head)),
                    Residual(
                        FeedForward(dim=out_chan, mult=ff_mult, images=True)),
                )

                blocks.append(nn.ModuleList([block, attn_block]))

            image_resolution //= 2

        self.blocks = nn.ModuleList(blocks)

        dim_last = layer_dims[-1]

        downsample_factor = 2**num_layers
        last_fmap_size = tuple(
            map(lambda n: n // downsample_factor, image_size))

        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last

        self.to_logits = nn.Sequential(
            nn.Conv2d(dim_last, dim_last, 3, padding=1),
            leaky_relu(),
            Rearrange('b ... -> b (...)'),
            nn.Linear(latent_dim, 1),
            Rearrange('b 1 -> b'),
        )

    def forward(self, x):
        for i, layer in enumerate(self.blocks):
            if i < self.temporal_num_layers:
                x = layer(x)
                if i == self.temporal_num_layers - 1:
                    x = x.mean(dim=2)
                    # x = rearrange(x, 'b c f h w -> (b f) c h w')
            else:
                block, attn_block = layer
                x = block(x)
                x = attn_block(x)

        return self.to_logits(x)
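# The "Withfirstframe" variant averages the remaining time axis (x.mean(dim=2))
# instead of folding frames into the batch, so it always emits one logit per
# clip even when the temporal stage leaves more than one frame behind.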
class VideoAutoencoderLoss(nn.Module):

    def __init__(
        self,
        disc_start,
        perceptual_weight=1,
        adversarial_loss_weight=0,
        multiscale_adversarial_loss_weight=0,
        grad_penalty_loss_weight=0,
        quantizer_aux_loss_weight=0,
        vgg_weights=VGG16_Weights.DEFAULT,
        discr_kwargs=None,
        discr_3d_kwargs=None,
    ):
        super().__init__()

        self.disc_start = disc_start
        self.perceptual_weight = perceptual_weight
        self.adversarial_loss_weight = adversarial_loss_weight
        self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight
        self.grad_penalty_loss_weight = grad_penalty_loss_weight
        self.quantizer_aux_loss_weight = quantizer_aux_loss_weight

        if self.perceptual_weight > 0:
            self.perceptual_model = LPIPS().eval()
            # self.vgg = torchvision.models.vgg16(pretrained = True)
            # self.vgg.requires_grad_(False)
        # if self.adversarial_loss_weight > 0:
        #     self.discr = Discriminator(**discr_kwargs)
        # else:
        #     self.discr = None
        # if self.multiscale_adversarial_loss_weight > 0:
        #     self.multiscale_discrs = nn.ModuleList([*multiscale_discrs])
        # else:
        #     self.multiscale_discrs = None
        if discr_kwargs is not None:
            self.discr = Discriminator(**discr_kwargs)
        else:
            self.discr = None
        if discr_3d_kwargs is not None:
            # self.discr_3d = Discriminator3D(**discr_3d_kwargs)
            self.discr_3d = instantiate_from_config(discr_3d_kwargs)
        else:
            self.discr_3d = None
        # self.multiscale_discrs = nn.ModuleList([*multiscale_discrs])

        self.register_buffer('zero', torch.tensor(0.0), persistent=False)

    def get_trainable_params(self) -> Any:
        params = []
        if self.discr is not None:
            params += list(self.discr.parameters())
        if self.discr_3d is not None:
            params += list(self.discr_3d.parameters())
        # if self.multiscale_discrs is not None:
        #     for discr in self.multiscale_discrs:
        #         params += list(discr.parameters())
        return params

    def get_trainable_parameters(self) -> Any:
        return self.get_trainable_params()

    def forward(
        self,
        inputs,
        reconstructions,
        optimizer_idx,
        global_step,
        aux_losses=None,
        last_layer=None,
        split='train',
    ):
        batch, channels, frames = inputs.shape[:3]

        if optimizer_idx == 0:
            recon_loss = F.mse_loss(inputs, reconstructions)

            if self.perceptual_weight > 0:
                frame_indices = torch.randn(
                    (batch, frames)).topk(1, dim=-1).indices

                input_frames = pick_video_frame(inputs, frame_indices)
                recon_frames = pick_video_frame(reconstructions,
                                                frame_indices)

                perceptual_loss = self.perceptual_model(
                    input_frames.contiguous(),
                    recon_frames.contiguous()).mean()
            else:
                perceptual_loss = self.zero

            # the adversarial term only kicks in once global_step reaches
            # disc_start; before that (and at eval time) it is zeroed out
            if global_step < self.disc_start or not self.training or self.adversarial_loss_weight == 0:
                gen_loss = self.zero
                adaptive_weight = 0
            else:
                # frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
                # recon_video_frames = pick_video_frame(reconstructions, frame_indices)
                # fake_logits = self.discr(recon_video_frames)
                fake_logits = self.discr_3d(reconstructions)
                gen_loss = hinge_gen_loss(fake_logits)

                adaptive_weight = 1

                if self.perceptual_weight > 0 and last_layer is not None:
                    norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
                        perceptual_loss, last_layer).norm(p=2)
                    norm_grad_wrt_gen_loss = grad_layer_wrt_loss(
                        gen_loss, last_layer).norm(p=2)

                    adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
                        min=1e-3)
                    adaptive_weight.clamp_(max=1e3)

                    if torch.isnan(adaptive_weight).any():
                        adaptive_weight = 1

            # multiscale discriminator losses
            # multiscale_gen_losses = []
            # multiscale_gen_adaptive_weights = []
            # if self.multiscale_adversarial_loss_weight > 0:
            #     if not exists(recon_video_frames):
            #         frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
            #         recon_video_frames = pick_video_frame(reconstructions, frame_indices)
            #     for discr in self.multiscale_discrs:
            #         fake_logits = recon_video_frames
            #         multiscale_gen_loss = hinge_gen_loss(fake_logits)
            #         multiscale_gen_losses.append(multiscale_gen_loss)
            #         multiscale_adaptive_weight = 1.
            #         if exists(norm_grad_wrt_perceptual_loss):
            #             norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_layer).norm(p = 2)
            #             multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min = 1e-5)
            #             multiscale_adaptive_weight.clamp_(max = 1e3)
            #         multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)
            #     weighted_multiscale_gen_losses = sum(loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights))
            # else:
            #     weighted_multiscale_gen_losses = self.zero

            if aux_losses is None:
                aux_losses = self.zero

            total_loss = (recon_loss +
                          aux_losses * self.quantizer_aux_loss_weight +
                          perceptual_loss * self.perceptual_weight +
                          gen_loss * self.adversarial_loss_weight)
            # gen_loss * adaptive_weight * self.adversarial_loss_weight + \
            # weighted_multiscale_gen_losses * self.multiscale_adversarial_loss_weight

            log = {
                f'{split}/total_loss': total_loss.detach(),
                f'{split}/recon_loss': recon_loss.detach(),
                f'{split}/perceptual_loss': perceptual_loss.detach(),
                f'{split}/gen_loss': gen_loss.detach(),
                f'{split}/aux_losses': aux_losses.detach(),
                # "{}/weighted_multiscale_gen_losses".format(split): weighted_multiscale_gen_losses.detach(),
                f'{split}/adaptive_weight': adaptive_weight,
                # "{}/multiscale_adaptive_weights".format(split): sum(multiscale_gen_adaptive_weights),
            }

            return total_loss, log

        if optimizer_idx == 1:
            # frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
            # real = pick_video_frame(inputs, frame_indices)
            # fake = pick_video_frame(reconstructions, frame_indices)
            # apply_gradient_penalty = self.grad_penalty_loss_weight > 0
            # if apply_gradient_penalty:
            #     real = real.requires_grad_()
            # real_logits = self.discr(real)
            # fake_logits = self.discr(fake.detach())

            apply_gradient_penalty = self.grad_penalty_loss_weight > 0
            if apply_gradient_penalty:
                inputs = inputs.requires_grad_()

            real_logits = self.discr_3d(inputs)
            fake_logits = self.discr_3d(reconstructions.detach())

            discr_loss = hinge_discr_loss(fake_logits, real_logits)

            # # multiscale discriminators
            # multiscale_discr_losses = []
            # if self.multiscale_adversarial_loss_weight > 0:
            #     for discr in self.multiscale_discrs:
            #         multiscale_real_logits = discr(inputs)
            #         multiscale_fake_logits = discr(reconstructions.detach())
            #         multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits)
            #         multiscale_discr_losses.append(multiscale_discr_loss)
            # else:
            #     multiscale_discr_losses.append(self.zero)

            # gradient penalty
            if apply_gradient_penalty:
                # gradient_penalty_loss = gradient_penalty(real, real_logits)
                gradient_penalty_loss = gradient_penalty(inputs, real_logits)
            else:
                gradient_penalty_loss = self.zero

            total_loss = discr_loss + self.grad_penalty_loss_weight * gradient_penalty_loss
            # self.grad_penalty_loss_weight * gradient_penalty_loss + \
            # sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight

            log = {
                f'{split}/total_disc_loss': total_loss.detach(),
                f'{split}/discr_loss': discr_loss.detach(),
                f'{split}/grad_penalty_loss': gradient_penalty_loss.detach(),
                # "{}/multiscale_discr_loss".format(split): sum(multiscale_discr_losses).detach(),
                f'{split}/logits_real': real_logits.detach().mean(),
                f'{split}/logits_fake': fake_logits.detach().mean(),
            }

            return total_loss, log
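The loss module is written for the usual two-optimizer GAN schedule: optimizer_idx == 0 updates the autoencoder (reconstruction, LPIPS on one randomly picked frame per clip, and the hinge generator term once global_step reaches disc_start), while optimizer_idx == 1 updates the 3D discriminator. A minimal wiring sketch, assuming an autoencoder `ae`, a step counter `step`, and the sgm-style {'target', 'params'} config dict for instantiate_from_config; none of these names are part of this commit:

    loss_fn = VideoAutoencoderLoss(
        disc_start=10_000,
        perceptual_weight=1.0,
        adversarial_loss_weight=0.05,
        discr_3d_kwargs={
            'target': 'sgm.modules.autoencoding.losses.video_loss.Discriminator3D',
            'params': dict(dim=32, image_size=64, frame_num=8),
        },
    )

    video = torch.randn(1, 3, 8, 64, 64)   # (b, c, f, h, w)
    recon = ae(video)                       # hypothetical autoencoder forward pass

    g_loss, g_log = loss_fn(video, recon, optimizer_idx=0, global_step=step)
    d_loss, d_log = loss_fn(video, recon, optimizer_idx=1, global_step=step)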
flashvideo/sgm/modules/autoencoding/lpips/__init__.py
0 → 100644
View file @
3b804999
flashvideo/sgm/modules/autoencoding/lpips/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
3b804999
File added
flashvideo/sgm/modules/autoencoding/lpips/__pycache__/util.cpython-310.pyc
0 → 100644
View file @
3b804999
File added