Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
fb8fae6f
Unverified
Commit
fb8fae6f
authored
Apr 06, 2023
by
NatalieC323
Committed by
GitHub
Apr 06, 2023
Browse files
Revert "[dreambooth] fixing the incompatibity in requirements.txt (#3190) (#3378)" (#3481)
parent
891b8e7f
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
98 additions
and
124 deletions
+98
-124
examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml
...ges/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml
+8
-8
examples/images/diffusion/configs/train_colossalai.yaml
examples/images/diffusion/configs/train_colossalai.yaml
+11
-11
examples/images/diffusion/configs/train_colossalai_cifar10.yaml
...es/images/diffusion/configs/train_colossalai_cifar10.yaml
+8
-8
examples/images/diffusion/configs/train_ddp.yaml
examples/images/diffusion/configs/train_ddp.yaml
+7
-7
examples/images/diffusion/ldm/models/autoencoder.py
examples/images/diffusion/ldm/models/autoencoder.py
+3
-2
examples/images/diffusion/ldm/models/diffusion/classifier.py
examples/images/diffusion/ldm/models/diffusion/classifier.py
+4
-5
examples/images/diffusion/ldm/models/diffusion/ddpm.py
examples/images/diffusion/ldm/models/diffusion/ddpm.py
+10
-12
examples/images/diffusion/main.py
examples/images/diffusion/main.py
+35
-59
examples/images/diffusion/scripts/img2img.py
examples/images/diffusion/scripts/img2img.py
+2
-2
examples/images/diffusion/scripts/inpaint.py
examples/images/diffusion/scripts/inpaint.py
+2
-2
examples/images/diffusion/scripts/knn2img.py
examples/images/diffusion/scripts/knn2img.py
+2
-3
examples/images/diffusion/scripts/sample_diffusion.py
examples/images/diffusion/scripts/sample_diffusion.py
+2
-1
examples/images/diffusion/scripts/tests/test_checkpoint.py
examples/images/diffusion/scripts/tests/test_checkpoint.py
+2
-2
examples/images/diffusion/scripts/txt2img.py
examples/images/diffusion/scripts/txt2img.py
+2
-2
No files found.
examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml
View file @
fb8fae6f
model
:
model
:
base_learning_rate
:
1.0e-4
base_learning_rate
:
1.0e-4
#
target: ldm.models.diffusion.ddpm.LatentDiffusion
target
:
ldm.models.diffusion.ddpm.LatentDiffusion
params
:
params
:
parameterization
:
"
v"
parameterization
:
"
v"
linear_start
:
0.00085
linear_start
:
0.00085
...
@@ -20,7 +20,7 @@ model:
...
@@ -20,7 +20,7 @@ model:
use_ema
:
False
use_ema
:
False
scheduler_config
:
# 10000 warmup steps
scheduler_config
:
# 10000 warmup steps
#
target: ldm.lr_scheduler.LambdaLinearScheduler
target
:
ldm.lr_scheduler.LambdaLinearScheduler
params
:
params
:
warm_up_steps
:
[
1
]
# NOTE for resuming. use 10000 if starting from scratch
warm_up_steps
:
[
1
]
# NOTE for resuming. use 10000 if starting from scratch
cycle_lengths
:
[
10000000000000
]
# incredibly large number to prevent corner cases
cycle_lengths
:
[
10000000000000
]
# incredibly large number to prevent corner cases
...
@@ -30,7 +30,7 @@ model:
...
@@ -30,7 +30,7 @@ model:
unet_config
:
unet_config
:
#
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
target
:
ldm.modules.diffusionmodules.openaimodel.UNetModel
params
:
params
:
use_checkpoint
:
True
use_checkpoint
:
True
use_fp16
:
True
use_fp16
:
True
...
@@ -49,7 +49,7 @@ model:
...
@@ -49,7 +49,7 @@ model:
legacy
:
False
legacy
:
False
first_stage_config
:
first_stage_config
:
#
target: ldm.models.autoencoder.AutoencoderKL
target
:
ldm.models.autoencoder.AutoencoderKL
params
:
params
:
embed_dim
:
4
embed_dim
:
4
monitor
:
val/rec_loss
monitor
:
val/rec_loss
...
@@ -73,13 +73,13 @@ model:
...
@@ -73,13 +73,13 @@ model:
target
:
torch.nn.Identity
target
:
torch.nn.Identity
cond_stage_config
:
cond_stage_config
:
#
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
target
:
ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params
:
params
:
freeze
:
True
freeze
:
True
layer
:
"
penultimate"
layer
:
"
penultimate"
data
:
data
:
#
target: main.DataModuleFromConfig
target
:
main.DataModuleFromConfig
params
:
params
:
batch_size
:
16
batch_size
:
16
num_workers
:
4
num_workers
:
4
...
@@ -105,7 +105,7 @@ lightning:
...
@@ -105,7 +105,7 @@ lightning:
precision
:
16
precision
:
16
auto_select_gpus
:
False
auto_select_gpus
:
False
strategy
:
strategy
:
#
target: strategies.ColossalAIStrategy
target
:
strategies.ColossalAIStrategy
params
:
params
:
use_chunk
:
True
use_chunk
:
True
enable_distributed_storage
:
True
enable_distributed_storage
:
True
...
@@ -120,7 +120,7 @@ lightning:
...
@@ -120,7 +120,7 @@ lightning:
logger_config
:
logger_config
:
wandb
:
wandb
:
#
target: loggers.WandbLogger
target
:
loggers.WandbLogger
params
:
params
:
name
:
nowname
name
:
nowname
save_dir
:
"
/tmp/diff_log/"
save_dir
:
"
/tmp/diff_log/"
...
...
examples/images/diffusion/configs/train_colossalai.yaml
View file @
fb8fae6f
model
:
model
:
base_learning_rate
:
1.0e-4
base_learning_rate
:
1.0e-4
#
target: ldm.models.diffusion.ddpm.LatentDiffusion
target
:
ldm.models.diffusion.ddpm.LatentDiffusion
params
:
params
:
parameterization
:
"
v"
parameterization
:
"
v"
linear_start
:
0.00085
linear_start
:
0.00085
...
@@ -19,7 +19,7 @@ model:
...
@@ -19,7 +19,7 @@ model:
use_ema
:
False
# we set this to false because this is an inference only config
use_ema
:
False
# we set this to false because this is an inference only config
scheduler_config
:
# 10000 warmup steps
scheduler_config
:
# 10000 warmup steps
#
target: ldm.lr_scheduler.LambdaLinearScheduler
target
:
ldm.lr_scheduler.LambdaLinearScheduler
params
:
params
:
warm_up_steps
:
[
1
]
# NOTE for resuming. use 10000 if starting from scratch
warm_up_steps
:
[
1
]
# NOTE for resuming. use 10000 if starting from scratch
cycle_lengths
:
[
10000000000000
]
# incredibly large number to prevent corner cases
cycle_lengths
:
[
10000000000000
]
# incredibly large number to prevent corner cases
...
@@ -29,7 +29,7 @@ model:
...
@@ -29,7 +29,7 @@ model:
unet_config
:
unet_config
:
#
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
target
:
ldm.modules.diffusionmodules.openaimodel.UNetModel
params
:
params
:
use_checkpoint
:
True
use_checkpoint
:
True
use_fp16
:
True
use_fp16
:
True
...
@@ -48,7 +48,7 @@ model:
...
@@ -48,7 +48,7 @@ model:
legacy
:
False
legacy
:
False
first_stage_config
:
first_stage_config
:
#
target: ldm.models.autoencoder.AutoencoderKL
target
:
ldm.models.autoencoder.AutoencoderKL
params
:
params
:
embed_dim
:
4
embed_dim
:
4
monitor
:
val/rec_loss
monitor
:
val/rec_loss
...
@@ -69,16 +69,16 @@ model:
...
@@ -69,16 +69,16 @@ model:
attn_resolutions
:
[]
attn_resolutions
:
[]
dropout
:
0.0
dropout
:
0.0
lossconfig
:
lossconfig
:
#
target: torch.nn.Identity
target
:
torch.nn.Identity
cond_stage_config
:
cond_stage_config
:
#
target:
#
ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
target
:
ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params
:
params
:
freeze
:
True
freeze
:
True
layer
:
"
penultimate"
layer
:
"
penultimate"
data
:
data
:
#
target:
#
main.DataModuleFromConfig
target
:
main.DataModuleFromConfig
params
:
params
:
batch_size
:
128
batch_size
:
128
wrap
:
False
wrap
:
False
...
@@ -88,20 +88,20 @@ data:
...
@@ -88,20 +88,20 @@ data:
train
:
train
:
target
:
ldm.data.base.Txt2ImgIterableBaseDataset
target
:
ldm.data.base.Txt2ImgIterableBaseDataset
params
:
params
:
file_path
:
/data/scratch/diffuser/laion_part0/
file_path
:
# YOUR DATASET_PATH
world_size
:
1
world_size
:
1
rank
:
0
rank
:
0
lightning
:
lightning
:
trainer
:
trainer
:
accelerator
:
'
gpu'
accelerator
:
'
gpu'
devices
:
2
devices
:
8
log_gpu_memory
:
all
log_gpu_memory
:
all
max_epochs
:
2
max_epochs
:
2
precision
:
16
precision
:
16
auto_select_gpus
:
False
auto_select_gpus
:
False
strategy
:
strategy
:
#
target:
#
strategies.ColossalAIStrategy
target
:
strategies.ColossalAIStrategy
params
:
params
:
use_chunk
:
True
use_chunk
:
True
enable_distributed_storage
:
True
enable_distributed_storage
:
True
...
@@ -116,7 +116,7 @@ lightning:
...
@@ -116,7 +116,7 @@ lightning:
logger_config
:
logger_config
:
wandb
:
wandb
:
#
target:
#
loggers.WandbLogger
target
:
loggers.WandbLogger
params
:
params
:
name
:
nowname
name
:
nowname
save_dir
:
"
/tmp/diff_log/"
save_dir
:
"
/tmp/diff_log/"
...
...
examples/images/diffusion/configs/train_colossalai_cifar10.yaml
View file @
fb8fae6f
model
:
model
:
base_learning_rate
:
1.0e-4
base_learning_rate
:
1.0e-4
#
target: ldm.models.diffusion.ddpm.LatentDiffusion
target
:
ldm.models.diffusion.ddpm.LatentDiffusion
params
:
params
:
parameterization
:
"
v"
parameterization
:
"
v"
linear_start
:
0.00085
linear_start
:
0.00085
...
@@ -19,7 +19,7 @@ model:
...
@@ -19,7 +19,7 @@ model:
use_ema
:
False
# we set this to false because this is an inference only config
use_ema
:
False
# we set this to false because this is an inference only config
scheduler_config
:
# 10000 warmup steps
scheduler_config
:
# 10000 warmup steps
#
target: ldm.lr_scheduler.LambdaLinearScheduler
target
:
ldm.lr_scheduler.LambdaLinearScheduler
params
:
params
:
warm_up_steps
:
[
1
]
# NOTE for resuming. use 10000 if starting from scratch
warm_up_steps
:
[
1
]
# NOTE for resuming. use 10000 if starting from scratch
cycle_lengths
:
[
10000000000000
]
# incredibly large number to prevent corner cases
cycle_lengths
:
[
10000000000000
]
# incredibly large number to prevent corner cases
...
@@ -29,7 +29,7 @@ model:
...
@@ -29,7 +29,7 @@ model:
unet_config
:
unet_config
:
#
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
target
:
ldm.modules.diffusionmodules.openaimodel.UNetModel
params
:
params
:
use_checkpoint
:
True
use_checkpoint
:
True
use_fp16
:
True
use_fp16
:
True
...
@@ -48,7 +48,7 @@ model:
...
@@ -48,7 +48,7 @@ model:
legacy
:
False
legacy
:
False
first_stage_config
:
first_stage_config
:
#
target: ldm.models.autoencoder.AutoencoderKL
target
:
ldm.models.autoencoder.AutoencoderKL
params
:
params
:
embed_dim
:
4
embed_dim
:
4
monitor
:
val/rec_loss
monitor
:
val/rec_loss
...
@@ -69,16 +69,16 @@ model:
...
@@ -69,16 +69,16 @@ model:
attn_resolutions
:
[]
attn_resolutions
:
[]
dropout
:
0.0
dropout
:
0.0
lossconfig
:
lossconfig
:
#
target: torch.nn.Identity
target
:
torch.nn.Identity
cond_stage_config
:
cond_stage_config
:
#
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
target
:
ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params
:
params
:
freeze
:
True
freeze
:
True
layer
:
"
penultimate"
layer
:
"
penultimate"
data
:
data
:
#
target: main.DataModuleFromConfig
target
:
main.DataModuleFromConfig
params
:
params
:
batch_size
:
4
batch_size
:
4
num_workers
:
4
num_workers
:
4
...
@@ -105,7 +105,7 @@ lightning:
...
@@ -105,7 +105,7 @@ lightning:
precision
:
16
precision
:
16
auto_select_gpus
:
False
auto_select_gpus
:
False
strategy
:
strategy
:
#
target: strategies.ColossalAIStrategy
target
:
strategies.ColossalAIStrategy
params
:
params
:
use_chunk
:
True
use_chunk
:
True
enable_distributed_storage
:
True
enable_distributed_storage
:
True
...
...
examples/images/diffusion/configs/train_ddp.yaml
View file @
fb8fae6f
model
:
model
:
base_learning_rate
:
1.0e-4
base_learning_rate
:
1.0e-4
#
target: ldm.models.diffusion.ddpm.LatentDiffusion
target
:
ldm.models.diffusion.ddpm.LatentDiffusion
params
:
params
:
parameterization
:
"
v"
parameterization
:
"
v"
linear_start
:
0.00085
linear_start
:
0.00085
...
@@ -29,7 +29,7 @@ model:
...
@@ -29,7 +29,7 @@ model:
unet_config
:
unet_config
:
#
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
target
:
ldm.modules.diffusionmodules.openaimodel.UNetModel
params
:
params
:
use_checkpoint
:
True
use_checkpoint
:
True
use_fp16
:
True
use_fp16
:
True
...
@@ -48,7 +48,7 @@ model:
...
@@ -48,7 +48,7 @@ model:
legacy
:
False
legacy
:
False
first_stage_config
:
first_stage_config
:
#
target: ldm.models.autoencoder.AutoencoderKL
target
:
ldm.models.autoencoder.AutoencoderKL
params
:
params
:
embed_dim
:
4
embed_dim
:
4
monitor
:
val/rec_loss
monitor
:
val/rec_loss
...
@@ -72,13 +72,13 @@ model:
...
@@ -72,13 +72,13 @@ model:
target
:
torch.nn.Identity
target
:
torch.nn.Identity
cond_stage_config
:
cond_stage_config
:
#
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
target
:
ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params
:
params
:
freeze
:
True
freeze
:
True
layer
:
"
penultimate"
layer
:
"
penultimate"
data
:
data
:
#
target: main.DataModuleFromConfig
target
:
main.DataModuleFromConfig
params
:
params
:
batch_size
:
128
batch_size
:
128
# num_workers should be 2 * batch_size, and the total num less than 1024
# num_workers should be 2 * batch_size, and the total num less than 1024
...
@@ -100,7 +100,7 @@ lightning:
...
@@ -100,7 +100,7 @@ lightning:
precision
:
16
precision
:
16
auto_select_gpus
:
False
auto_select_gpus
:
False
strategy
:
strategy
:
#
target: strategies.DDPStrategy
target
:
strategies.DDPStrategy
params
:
params
:
find_unused_parameters
:
False
find_unused_parameters
:
False
log_every_n_steps
:
2
log_every_n_steps
:
2
...
@@ -111,7 +111,7 @@ lightning:
...
@@ -111,7 +111,7 @@ lightning:
logger_config
:
logger_config
:
wandb
:
wandb
:
#
target: loggers.WandbLogger
target
:
loggers.WandbLogger
params
:
params
:
name
:
nowname
name
:
nowname
save_dir
:
"
/data2/tmp/diff_log/"
save_dir
:
"
/data2/tmp/diff_log/"
...
...
examples/images/diffusion/ldm/models/autoencoder.py
View file @
fb8fae6f
...
@@ -6,10 +6,11 @@ except:
...
@@ -6,10 +6,11 @@ except:
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
torch.nn
import
Identity
from
ldm.modules.diffusionmodules.model
import
Encoder
,
Decoder
from
ldm.modules.diffusionmodules.model
import
Encoder
,
Decoder
from
ldm.modules.distributions.distributions
import
DiagonalGaussianDistribution
from
ldm.modules.distributions.distributions
import
DiagonalGaussianDistribution
from
ldm.util
import
instantiate_from_config
from
ldm.modules.ema
import
LitEma
from
ldm.modules.ema
import
LitEma
...
@@ -31,7 +32,7 @@ class AutoencoderKL(pl.LightningModule):
...
@@ -31,7 +32,7 @@ class AutoencoderKL(pl.LightningModule):
self
.
image_key
=
image_key
self
.
image_key
=
image_key
self
.
encoder
=
Encoder
(
**
ddconfig
)
self
.
encoder
=
Encoder
(
**
ddconfig
)
self
.
decoder
=
Decoder
(
**
ddconfig
)
self
.
decoder
=
Decoder
(
**
ddconfig
)
self
.
loss
=
Identity
(
**
lossconfig
.
get
(
"params"
,
dict
())
)
self
.
loss
=
instantiate_from_config
(
lossconfig
)
assert
ddconfig
[
"double_z"
]
assert
ddconfig
[
"double_z"
]
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
2
*
ddconfig
[
"z_channels"
],
2
*
embed_dim
,
1
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
2
*
ddconfig
[
"z_channels"
],
2
*
embed_dim
,
1
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
ddconfig
[
"z_channels"
],
1
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
ddconfig
[
"z_channels"
],
1
)
...
...
examples/images/diffusion/ldm/models/diffusion/classifier.py
View file @
fb8fae6f
...
@@ -9,10 +9,9 @@ from copy import deepcopy
...
@@ -9,10 +9,9 @@ from copy import deepcopy
from
einops
import
rearrange
from
einops
import
rearrange
from
glob
import
glob
from
glob
import
glob
from
natsort
import
natsorted
from
natsort
import
natsorted
from
ldm.models.diffusion.ddpm
import
LatentDiffusion
from
ldm.lr_scheduler
import
LambdaLinearScheduler
from
ldm.modules.diffusionmodules.openaimodel
import
EncoderUNetModel
,
UNetModel
from
ldm.modules.diffusionmodules.openaimodel
import
EncoderUNetModel
,
UNetModel
from
ldm.util
import
log_txt_as_img
,
default
,
ismap
from
ldm.util
import
log_txt_as_img
,
default
,
ismap
,
instantiate_from_config
__models__
=
{
__models__
=
{
'class_label'
:
EncoderUNetModel
,
'class_label'
:
EncoderUNetModel
,
...
@@ -87,7 +86,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
...
@@ -87,7 +86,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
print
(
f
"Unexpected Keys:
{
unexpected
}
"
)
print
(
f
"Unexpected Keys:
{
unexpected
}
"
)
def
load_diffusion
(
self
):
def
load_diffusion
(
self
):
model
=
LatentDiffusion
(
**
self
.
diffusion_config
.
get
(
'params'
,
dict
())
)
model
=
instantiate_from_config
(
self
.
diffusion_config
)
self
.
diffusion_model
=
model
.
eval
()
self
.
diffusion_model
=
model
.
eval
()
self
.
diffusion_model
.
train
=
disabled_train
self
.
diffusion_model
.
train
=
disabled_train
for
param
in
self
.
diffusion_model
.
parameters
():
for
param
in
self
.
diffusion_model
.
parameters
():
...
@@ -222,7 +221,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
...
@@ -222,7 +221,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
optimizer
=
AdamW
(
self
.
model
.
parameters
(),
lr
=
self
.
learning_rate
,
weight_decay
=
self
.
weight_decay
)
optimizer
=
AdamW
(
self
.
model
.
parameters
(),
lr
=
self
.
learning_rate
,
weight_decay
=
self
.
weight_decay
)
if
self
.
use_scheduler
:
if
self
.
use_scheduler
:
scheduler
=
LambdaLinearScheduler
(
**
self
.
scheduler_config
.
get
(
'params'
,
dict
())
)
scheduler
=
instantiate_from_config
(
self
.
scheduler_config
)
print
(
"Setting up LambdaLR scheduler..."
)
print
(
"Setting up LambdaLR scheduler..."
)
scheduler
=
[
scheduler
=
[
...
...
examples/images/diffusion/ldm/models/diffusion/ddpm.py
View file @
fb8fae6f
...
@@ -22,7 +22,6 @@ from contextlib import contextmanager, nullcontext
...
@@ -22,7 +22,6 @@ from contextlib import contextmanager, nullcontext
from
functools
import
partial
from
functools
import
partial
from
einops
import
rearrange
,
repeat
from
einops
import
rearrange
,
repeat
from
ldm.lr_scheduler
import
LambdaLinearScheduler
from
ldm.models.autoencoder
import
*
from
ldm.models.autoencoder
import
*
from
ldm.models.autoencoder
import
AutoencoderKL
,
IdentityFirstStage
from
ldm.models.autoencoder
import
AutoencoderKL
,
IdentityFirstStage
from
ldm.models.diffusion.ddim
import
*
from
ldm.models.diffusion.ddim
import
*
...
@@ -30,10 +29,9 @@ from ldm.models.diffusion.ddim import DDIMSampler
...
@@ -30,10 +29,9 @@ from ldm.models.diffusion.ddim import DDIMSampler
from
ldm.modules.diffusionmodules.model
import
*
from
ldm.modules.diffusionmodules.model
import
*
from
ldm.modules.diffusionmodules.model
import
Decoder
,
Encoder
,
Model
from
ldm.modules.diffusionmodules.model
import
Decoder
,
Encoder
,
Model
from
ldm.modules.diffusionmodules.openaimodel
import
*
from
ldm.modules.diffusionmodules.openaimodel
import
*
from
ldm.modules.diffusionmodules.openaimodel
import
AttentionPool2d
,
UNetModel
from
ldm.modules.diffusionmodules.openaimodel
import
AttentionPool2d
from
ldm.modules.diffusionmodules.util
import
extract_into_tensor
,
make_beta_schedule
,
noise_like
from
ldm.modules.diffusionmodules.util
import
extract_into_tensor
,
make_beta_schedule
,
noise_like
from
ldm.modules.distributions.distributions
import
DiagonalGaussianDistribution
,
normal_kl
from
ldm.modules.distributions.distributions
import
DiagonalGaussianDistribution
,
normal_kl
from
ldm.modules.diffusionmodules.upscaling
import
ImageConcatWithNoiseAugmentation
from
ldm.modules.ema
import
LitEma
from
ldm.modules.ema
import
LitEma
from
ldm.modules.encoders.modules
import
*
from
ldm.modules.encoders.modules
import
*
from
ldm.util
import
count_params
,
default
,
exists
,
instantiate_from_config
,
isimage
,
ismap
,
log_txt_as_img
,
mean_flat
from
ldm.util
import
count_params
,
default
,
exists
,
instantiate_from_config
,
isimage
,
ismap
,
log_txt_as_img
,
mean_flat
...
@@ -41,7 +39,6 @@ from omegaconf import ListConfig
...
@@ -41,7 +39,6 @@ from omegaconf import ListConfig
from
torch.optim.lr_scheduler
import
LambdaLR
from
torch.optim.lr_scheduler
import
LambdaLR
from
torchvision.utils
import
make_grid
from
torchvision.utils
import
make_grid
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
ldm.modules.midas.api
import
MiDaSInference
__conditioning_keys__
=
{
'concat'
:
'c_concat'
,
'crossattn'
:
'c_crossattn'
,
'adm'
:
'y'
}
__conditioning_keys__
=
{
'concat'
:
'c_concat'
,
'crossattn'
:
'c_crossattn'
,
'adm'
:
'y'
}
...
@@ -693,7 +690,7 @@ class LatentDiffusion(DDPM):
...
@@ -693,7 +690,7 @@ class LatentDiffusion(DDPM):
self
.
make_cond_schedule
()
self
.
make_cond_schedule
()
def
instantiate_first_stage
(
self
,
config
):
def
instantiate_first_stage
(
self
,
config
):
model
=
AutoencoderKL
(
**
config
.
get
(
"params"
,
dict
())
)
model
=
instantiate_from_config
(
config
)
self
.
first_stage_model
=
model
.
eval
()
self
.
first_stage_model
=
model
.
eval
()
self
.
first_stage_model
.
train
=
disabled_train
self
.
first_stage_model
.
train
=
disabled_train
for
param
in
self
.
first_stage_model
.
parameters
():
for
param
in
self
.
first_stage_model
.
parameters
():
...
@@ -709,7 +706,7 @@ class LatentDiffusion(DDPM):
...
@@ -709,7 +706,7 @@ class LatentDiffusion(DDPM):
self
.
cond_stage_model
=
None
self
.
cond_stage_model
=
None
# self.be_unconditional = True
# self.be_unconditional = True
else
:
else
:
model
=
FrozenOpenCLIPEmbedder
(
**
config
.
get
(
"params"
,
dict
())
)
model
=
instantiate_from_config
(
config
)
self
.
cond_stage_model
=
model
.
eval
()
self
.
cond_stage_model
=
model
.
eval
()
self
.
cond_stage_model
.
train
=
disabled_train
self
.
cond_stage_model
.
train
=
disabled_train
for
param
in
self
.
cond_stage_model
.
parameters
():
for
param
in
self
.
cond_stage_model
.
parameters
():
...
@@ -717,7 +714,7 @@ class LatentDiffusion(DDPM):
...
@@ -717,7 +714,7 @@ class LatentDiffusion(DDPM):
else
:
else
:
assert
config
!=
'__is_first_stage__'
assert
config
!=
'__is_first_stage__'
assert
config
!=
'__is_unconditional__'
assert
config
!=
'__is_unconditional__'
model
=
FrozenOpenCLIPEmbedder
(
**
config
.
get
(
"params"
,
dict
())
)
model
=
instantiate_from_config
(
config
)
self
.
cond_stage_model
=
model
self
.
cond_stage_model
=
model
def
_get_denoise_row_from_list
(
self
,
samples
,
desc
=
''
,
force_no_decoder_quantization
=
False
):
def
_get_denoise_row_from_list
(
self
,
samples
,
desc
=
''
,
force_no_decoder_quantization
=
False
):
...
@@ -1482,7 +1479,8 @@ class LatentDiffusion(DDPM):
...
@@ -1482,7 +1479,8 @@ class LatentDiffusion(DDPM):
# opt = torch.optim.AdamW(params, lr=lr)
# opt = torch.optim.AdamW(params, lr=lr)
if
self
.
use_scheduler
:
if
self
.
use_scheduler
:
scheduler
=
LambdaLinearScheduler
(
**
self
.
scheduler_config
.
get
(
"params"
,
dict
()))
assert
'target'
in
self
.
scheduler_config
scheduler
=
instantiate_from_config
(
self
.
scheduler_config
)
rank_zero_info
(
"Setting up LambdaLR scheduler..."
)
rank_zero_info
(
"Setting up LambdaLR scheduler..."
)
scheduler
=
[{
'scheduler'
:
LambdaLR
(
opt
,
lr_lambda
=
scheduler
.
schedule
),
'interval'
:
'step'
,
'frequency'
:
1
}]
scheduler
=
[{
'scheduler'
:
LambdaLR
(
opt
,
lr_lambda
=
scheduler
.
schedule
),
'interval'
:
'step'
,
'frequency'
:
1
}]
...
@@ -1504,7 +1502,7 @@ class DiffusionWrapper(pl.LightningModule):
...
@@ -1504,7 +1502,7 @@ class DiffusionWrapper(pl.LightningModule):
def
__init__
(
self
,
diff_model_config
,
conditioning_key
):
def
__init__
(
self
,
diff_model_config
,
conditioning_key
):
super
().
__init__
()
super
().
__init__
()
self
.
sequential_cross_attn
=
diff_model_config
.
pop
(
"sequential_crossattn"
,
False
)
self
.
sequential_cross_attn
=
diff_model_config
.
pop
(
"sequential_crossattn"
,
False
)
self
.
diffusion_model
=
UNetModel
(
**
diff_model_config
.
get
(
"params"
,
dict
())
)
self
.
diffusion_model
=
instantiate_from_config
(
diff_model_config
)
self
.
conditioning_key
=
conditioning_key
self
.
conditioning_key
=
conditioning_key
assert
self
.
conditioning_key
in
[
None
,
'concat'
,
'crossattn'
,
'hybrid'
,
'adm'
,
'hybrid-adm'
,
'crossattn-adm'
]
assert
self
.
conditioning_key
in
[
None
,
'concat'
,
'crossattn'
,
'hybrid'
,
'adm'
,
'hybrid-adm'
,
'crossattn-adm'
]
...
@@ -1553,7 +1551,7 @@ class LatentUpscaleDiffusion(LatentDiffusion):
...
@@ -1553,7 +1551,7 @@ class LatentUpscaleDiffusion(LatentDiffusion):
self
.
noise_level_key
=
noise_level_key
self
.
noise_level_key
=
noise_level_key
def
instantiate_low_stage
(
self
,
config
):
def
instantiate_low_stage
(
self
,
config
):
model
=
ImageConcatWithNoiseAugmentation
(
**
config
.
get
(
"params"
,
dict
())
)
model
=
instantiate_from_config
(
config
)
self
.
low_scale_model
=
model
.
eval
()
self
.
low_scale_model
=
model
.
eval
()
self
.
low_scale_model
.
train
=
disabled_train
self
.
low_scale_model
.
train
=
disabled_train
for
param
in
self
.
low_scale_model
.
parameters
():
for
param
in
self
.
low_scale_model
.
parameters
():
...
@@ -1935,7 +1933,7 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
...
@@ -1935,7 +1933,7 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
def
__init__
(
self
,
depth_stage_config
,
concat_keys
=
(
"midas_in"
,),
*
args
,
**
kwargs
):
def
__init__
(
self
,
depth_stage_config
,
concat_keys
=
(
"midas_in"
,),
*
args
,
**
kwargs
):
super
().
__init__
(
concat_keys
=
concat_keys
,
*
args
,
**
kwargs
)
super
().
__init__
(
concat_keys
=
concat_keys
,
*
args
,
**
kwargs
)
self
.
depth_model
=
MiDaSInference
(
**
depth_stage_config
.
get
(
"params"
,
dict
())
)
self
.
depth_model
=
instantiate_from_config
(
depth_stage_config
)
self
.
depth_stage_key
=
concat_keys
[
0
]
self
.
depth_stage_key
=
concat_keys
[
0
]
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
@@ -2008,7 +2006,7 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
...
@@ -2008,7 +2006,7 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
self
.
low_scale_key
=
low_scale_key
self
.
low_scale_key
=
low_scale_key
def
instantiate_low_stage
(
self
,
config
):
def
instantiate_low_stage
(
self
,
config
):
model
=
ImageConcatWithNoiseAugmentation
(
**
config
.
get
(
"params"
,
dict
())
)
model
=
instantiate_from_config
(
config
)
self
.
low_scale_model
=
model
.
eval
()
self
.
low_scale_model
=
model
.
eval
()
self
.
low_scale_model
.
train
=
disabled_train
self
.
low_scale_model
.
train
=
disabled_train
for
param
in
self
.
low_scale_model
.
parameters
():
for
param
in
self
.
low_scale_model
.
parameters
():
...
...
examples/images/diffusion/main.py
View file @
fb8fae6f
...
@@ -23,21 +23,19 @@ from packaging import version
...
@@ -23,21 +23,19 @@ from packaging import version
from
PIL
import
Image
from
PIL
import
Image
from
prefetch_generator
import
BackgroundGenerator
from
prefetch_generator
import
BackgroundGenerator
from
torch.utils.data
import
DataLoader
,
Dataset
,
Subset
,
random_split
from
torch.utils.data
import
DataLoader
,
Dataset
,
Subset
,
random_split
from
ldm.models.diffusion.ddpm
import
LatentDiffusion
#try:
try
:
from
lightning.pytorch
import
seed_everything
from
lightning.pytorch
import
seed_everything
from
lightning.pytorch.callbacks
import
Callback
,
LearningRateMonitor
,
ModelCheckpoint
from
lightning.pytorch.callbacks
import
Callback
,
LearningRateMonitor
,
ModelCheckpoint
from
lightning.pytorch.trainer
import
Trainer
from
lightning.pytorch.trainer
import
Trainer
from
lightning.pytorch.utilities
import
rank_zero_info
,
rank_zero_only
from
lightning.pytorch.utilities
import
rank_zero_info
,
rank_zero_only
from
lightning.pytorch.loggers
import
WandbLogger
,
TensorBoardLogger
LIGHTNING_PACK_NAME
=
"lightning.pytorch."
from
lightning.pytorch.strategies
import
ColossalAIStrategy
,
DDPStrategy
except
:
LIGHTNING_PACK_NAME
=
"lightning.pytorch."
from
pytorch_lightning
import
seed_everything
# #except:
from
pytorch_lightning.callbacks
import
Callback
,
LearningRateMonitor
,
ModelCheckpoint
# from pytorch_lightning import seed_everything
from
pytorch_lightning.trainer
import
Trainer
# from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from
pytorch_lightning.utilities
import
rank_zero_info
,
rank_zero_only
# from pytorch_lightning.trainer import Trainer
LIGHTNING_PACK_NAME
=
"pytorch_lightning."
# from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
# LIGHTNING_PACK_NAME = "pytorch_lightning."
from
ldm.data.base
import
Txt2ImgIterableBaseDataset
from
ldm.data.base
import
Txt2ImgIterableBaseDataset
from
ldm.util
import
instantiate_from_config
from
ldm.util
import
instantiate_from_config
...
@@ -577,7 +575,7 @@ if __name__ == "__main__":
...
@@ -577,7 +575,7 @@ if __name__ == "__main__":
# target: path to test dataset
# target: path to test dataset
# params:
# params:
# key: value
# key: value
# lightning: (optional, has sa
m
e defaults and can be specified on cmdline)
# lightning: (optional, has sa
n
e defaults and can be specified on cmdline)
# trainer:
# trainer:
# additional arguments to trainer
# additional arguments to trainer
# logger:
# logger:
...
@@ -655,7 +653,7 @@ if __name__ == "__main__":
...
@@ -655,7 +653,7 @@ if __name__ == "__main__":
# Sets the seed for the random number generator to ensure reproducibility
# Sets the seed for the random number generator to ensure reproducibility
seed_everything
(
opt
.
seed
)
seed_everything
(
opt
.
seed
)
# Intinalize and save configuration using t
h
e OmegaConf library.
# Intinalize and save configuratio
o
n using te
h
OmegaConf library.
try
:
try
:
# init and save configs
# init and save configs
configs
=
[
OmegaConf
.
load
(
cfg
)
for
cfg
in
opt
.
base
]
configs
=
[
OmegaConf
.
load
(
cfg
)
for
cfg
in
opt
.
base
]
...
@@ -689,7 +687,7 @@ if __name__ == "__main__":
...
@@ -689,7 +687,7 @@ if __name__ == "__main__":
config
.
model
[
"params"
].
update
({
"ckpt"
:
ckpt
})
config
.
model
[
"params"
].
update
({
"ckpt"
:
ckpt
})
rank_zero_info
(
"Using ckpt_path = {}"
.
format
(
config
.
model
[
"params"
][
"ckpt"
]))
rank_zero_info
(
"Using ckpt_path = {}"
.
format
(
config
.
model
[
"params"
][
"ckpt"
]))
model
=
LatentDiffusion
(
**
config
.
model
.
get
(
"params"
,
dict
())
)
model
=
instantiate_from_config
(
config
.
model
)
# trainer and callbacks
# trainer and callbacks
trainer_kwargs
=
dict
()
trainer_kwargs
=
dict
()
...
@@ -698,7 +696,7 @@ if __name__ == "__main__":
...
@@ -698,7 +696,7 @@ if __name__ == "__main__":
# These loggers are specified as targets in the dictionary, along with the configuration settings specific to each logger.
# These loggers are specified as targets in the dictionary, along with the configuration settings specific to each logger.
default_logger_cfgs
=
{
default_logger_cfgs
=
{
"wandb"
:
{
"wandb"
:
{
#
"target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
"target"
:
LIGHTNING_PACK_NAME
+
"loggers.WandbLogger"
,
"params"
:
{
"params"
:
{
"name"
:
nowname
,
"name"
:
nowname
,
"save_dir"
:
logdir
,
"save_dir"
:
logdir
,
...
@@ -707,7 +705,7 @@ if __name__ == "__main__":
...
@@ -707,7 +705,7 @@ if __name__ == "__main__":
}
}
},
},
"tensorboard"
:
{
"tensorboard"
:
{
#
"target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
"target"
:
LIGHTNING_PACK_NAME
+
"loggers.TensorBoardLogger"
,
"params"
:
{
"params"
:
{
"save_dir"
:
logdir
,
"save_dir"
:
logdir
,
"name"
:
"diff_tb"
,
"name"
:
"diff_tb"
,
...
@@ -720,32 +718,30 @@ if __name__ == "__main__":
...
@@ -720,32 +718,30 @@ if __name__ == "__main__":
default_logger_cfg
=
default_logger_cfgs
[
"tensorboard"
]
default_logger_cfg
=
default_logger_cfgs
[
"tensorboard"
]
if
"logger"
in
lightning_config
:
if
"logger"
in
lightning_config
:
logger_cfg
=
lightning_config
.
logger
logger_cfg
=
lightning_config
.
logger
logger_cfg
=
OmegaConf
.
merge
(
default_logger_cfg
,
logger_cfg
)
trainer_kwargs
[
"logger"
]
=
WandbLogger
(
**
logger_cfg
.
get
(
"params"
,
dict
()))
else
:
else
:
logger_cfg
=
default_logger_cfg
logger_cfg
=
default_logger_cfg
logger_cfg
=
OmegaConf
.
merge
(
default_logger_cfg
,
logger_cfg
)
logger_cfg
=
OmegaConf
.
merge
(
default_logger_cfg
,
logger_cfg
)
trainer_kwargs
[
"logger"
]
=
TensorBoardLogger
(
**
logger_cfg
.
get
(
"params"
,
dict
()))
trainer_kwargs
[
"logger"
]
=
instantiate_from_config
(
logger_cfg
)
# config the strategy, defualt is ddp
# config the strategy, defualt is ddp
if
"strategy"
in
trainer_config
:
if
"strategy"
in
trainer_config
:
strategy_cfg
=
trainer_config
[
"strategy"
]
strategy_cfg
=
trainer_config
[
"strategy"
]
tra
iner_kwargs
[
"strategy"
]
=
ColossalAIStrategy
(
**
strategy_cfg
.
get
(
"params"
,
dict
()))
s
tra
tegy_cfg
[
"target"
]
=
LIGHTNING_PACK_NAME
+
strategy_cfg
[
"target"
]
else
:
else
:
strategy_cfg
=
{
strategy_cfg
=
{
#
"target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
"target"
:
LIGHTNING_PACK_NAME
+
"strategies.DDPStrategy"
,
"params"
:
{
"params"
:
{
"find_unused_parameters"
:
False
"find_unused_parameters"
:
False
}
}
}
}
trainer_kwargs
[
"strategy"
]
=
DDPStrategy
(
**
strategy_cfg
.
get
(
"params"
,
dict
()))
trainer_kwargs
[
"strategy"
]
=
instantiate_from_config
(
strategy_cfg
)
# Set up ModelCheckpoint callback to save best models
# Set up ModelCheckpoint callback to save best models
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
# specify which metric is used to determine best models
# specify which metric is used to determine best models
default_modelckpt_cfg
=
{
default_modelckpt_cfg
=
{
#
"target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
"target"
:
LIGHTNING_PACK_NAME
+
"callbacks.ModelCheckpoint"
,
"params"
:
{
"params"
:
{
"dirpath"
:
ckptdir
,
"dirpath"
:
ckptdir
,
"filename"
:
"{epoch:06}"
,
"filename"
:
"{epoch:06}"
,
...
@@ -763,13 +759,13 @@ if __name__ == "__main__":
...
@@ -763,13 +759,13 @@ if __name__ == "__main__":
modelckpt_cfg
=
OmegaConf
.
create
()
modelckpt_cfg
=
OmegaConf
.
create
()
modelckpt_cfg
=
OmegaConf
.
merge
(
default_modelckpt_cfg
,
modelckpt_cfg
)
modelckpt_cfg
=
OmegaConf
.
merge
(
default_modelckpt_cfg
,
modelckpt_cfg
)
if
version
.
parse
(
pl
.
__version__
)
<
version
.
parse
(
'1.4.0'
):
if
version
.
parse
(
pl
.
__version__
)
<
version
.
parse
(
'1.4.0'
):
trainer_kwargs
[
"checkpoint_callback"
]
=
ModelCheckpoint
(
**
modelckpt_cfg
.
get
(
"params"
,
dict
())
)
trainer_kwargs
[
"checkpoint_callback"
]
=
instantiate_from_config
(
modelckpt_cfg
)
# Set up various callbacks, including logging, learning rate monitoring, and CUDA management
# Set up various callbacks, including logging, learning rate monitoring, and CUDA management
# add callback which sets up log directory
# add callback which sets up log directory
default_callbacks_cfg
=
{
default_callbacks_cfg
=
{
"setup_callback"
:
{
# callback to set up the training
"setup_callback"
:
{
# callback to set up the training
#
"target": "main.SetupCallback",
"target"
:
"main.SetupCallback"
,
"params"
:
{
"params"
:
{
"resume"
:
opt
.
resume
,
# resume training if applicable
"resume"
:
opt
.
resume
,
# resume training if applicable
"now"
:
now
,
"now"
:
now
,
...
@@ -781,7 +777,7 @@ if __name__ == "__main__":
...
@@ -781,7 +777,7 @@ if __name__ == "__main__":
}
}
},
},
"image_logger"
:
{
# callback to log image data
"image_logger"
:
{
# callback to log image data
#
"target": "main.ImageLogger",
"target"
:
"main.ImageLogger"
,
"params"
:
{
"params"
:
{
"batch_frequency"
:
750
,
# how frequently to log images
"batch_frequency"
:
750
,
# how frequently to log images
"max_images"
:
4
,
# maximum number of images to log
"max_images"
:
4
,
# maximum number of images to log
...
@@ -789,14 +785,14 @@ if __name__ == "__main__":
...
@@ -789,14 +785,14 @@ if __name__ == "__main__":
}
}
},
},
"learning_rate_logger"
:
{
# callback to log learning rate
"learning_rate_logger"
:
{
# callback to log learning rate
#
"target": "main.LearningRateMonitor",
"target"
:
"main.LearningRateMonitor"
,
"params"
:
{
"params"
:
{
"logging_interval"
:
"step"
,
# logging frequency (either 'step' or 'epoch')
"logging_interval"
:
"step"
,
# logging frequency (either 'step' or 'epoch')
# "log_momentum": True # whether to log momentum (currently commented out)
# "log_momentum": True # whether to log momentum (currently commented out)
}
}
},
},
"cuda_callback"
:
{
# callback to handle CUDA-related operations
"cuda_callback"
:
{
# callback to handle CUDA-related operations
#
"target": "main.CUDACallback"
"target"
:
"main.CUDACallback"
},
},
}
}
...
@@ -814,7 +810,7 @@ if __name__ == "__main__":
...
@@ -814,7 +810,7 @@ if __name__ == "__main__":
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.'
)
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.'
)
default_metrics_over_trainsteps_ckpt_dict
=
{
default_metrics_over_trainsteps_ckpt_dict
=
{
'metrics_over_trainsteps_checkpoint'
:
{
'metrics_over_trainsteps_checkpoint'
:
{
#
"target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
"target"
:
LIGHTNING_PACK_NAME
+
'callbacks.ModelCheckpoint'
,
'params'
:
{
'params'
:
{
"dirpath"
:
os
.
path
.
join
(
ckptdir
,
'trainstep_checkpoints'
),
"dirpath"
:
os
.
path
.
join
(
ckptdir
,
'trainstep_checkpoints'
),
"filename"
:
"{epoch:06}-{step:09}"
,
"filename"
:
"{epoch:06}-{step:09}"
,
...
@@ -830,34 +826,14 @@ if __name__ == "__main__":
...
@@ -830,34 +826,14 @@ if __name__ == "__main__":
# Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks
# Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks
callbacks_cfg
=
OmegaConf
.
merge
(
default_callbacks_cfg
,
callbacks_cfg
)
callbacks_cfg
=
OmegaConf
.
merge
(
default_callbacks_cfg
,
callbacks_cfg
)
#Instantiate items according to the configs
trainer_kwargs
[
"callbacks"
]
=
[
instantiate_from_config
(
callbacks_cfg
[
k
])
for
k
in
callbacks_cfg
]
trainer_kwargs
.
setdefault
(
"callbacks"
,
[])
if
"setup_callback"
in
callbacks_cfg
:
setup_callback_config
=
callbacks_cfg
[
"setup_callback"
]
trainer_kwargs
[
"callbacks"
].
append
(
SetupCallback
(
**
setup_callback_config
.
get
(
"params"
,
dict
())))
if
"image_logger"
in
callbacks_cfg
:
image_logger_config
=
callbacks_cfg
[
"image_logger"
]
trainer_kwargs
[
"callbacks"
].
append
(
ImageLogger
(
**
image_logger_config
.
get
(
"params"
,
dict
())))
if
"learning_rate_logger"
in
callbacks_cfg
:
learning_rate_logger_config
=
callbacks_cfg
[
"learning_rate_logger"
]
trainer_kwargs
[
"callbacks"
].
append
(
LearningRateMonitor
(
**
learning_rate_logger_config
.
get
(
"params"
,
dict
())))
if
"cuda_callback"
in
callbacks_cfg
:
cuda_callback_config
=
callbacks_cfg
[
"cuda_callback"
]
trainer_kwargs
[
"callbacks"
].
append
(
CUDACallback
(
**
cuda_callback_config
.
get
(
"params"
,
dict
())))
if
"metrics_over_trainsteps_checkpoint"
in
callbacks_cfg
:
# Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory
metrics_over_config
=
callbacks_cfg
[
'metrics_over_trainsteps_checkpoint'
]
trainer_kwargs
[
"callbacks"
].
append
(
ModelCheckpoint
(
**
metrics_over_config
.
get
(
"params"
,
dict
())))
#trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer
=
Trainer
.
from_argparse_args
(
trainer_opt
,
**
trainer_kwargs
)
trainer
=
Trainer
.
from_argparse_args
(
trainer_opt
,
**
trainer_kwargs
)
trainer
.
logdir
=
logdir
trainer
.
logdir
=
logdir
# Create a data module based on the configuration file
# Create a data module based on the configuration file
data
=
DataModuleF
rom
C
onfig
(
**
config
.
data
.
get
(
"params"
,
dict
())
)
data
=
instantiate_f
rom
_c
onfig
(
config
.
data
)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
# lightning still takes care of proper multiprocessing though
...
...
examples/images/diffusion/scripts/img2img.py
View file @
fb8fae6f
...
@@ -20,8 +20,8 @@ from imwatermark import WatermarkEncoder
...
@@ -20,8 +20,8 @@ from imwatermark import WatermarkEncoder
from
scripts.txt2img
import
put_watermark
from
scripts.txt2img
import
put_watermark
from
ldm.util
import
instantiate_from_config
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.models.diffusion.ddpm
import
LatentDiffusion
from
utils
import
replace_module
,
getModelSize
from
utils
import
replace_module
,
getModelSize
...
@@ -36,7 +36,7 @@ def load_model_from_config(config, ckpt, verbose=False):
...
@@ -36,7 +36,7 @@ def load_model_from_config(config, ckpt, verbose=False):
if
"global_step"
in
pl_sd
:
if
"global_step"
in
pl_sd
:
print
(
f
"Global Step:
{
pl_sd
[
'global_step'
]
}
"
)
print
(
f
"Global Step:
{
pl_sd
[
'global_step'
]
}
"
)
sd
=
pl_sd
[
"state_dict"
]
sd
=
pl_sd
[
"state_dict"
]
model
=
LatentDiffusion
(
**
config
.
model
.
get
(
"params"
,
dict
())
)
model
=
instantiate_from_config
(
config
.
model
)
m
,
u
=
model
.
load_state_dict
(
sd
,
strict
=
False
)
m
,
u
=
model
.
load_state_dict
(
sd
,
strict
=
False
)
if
len
(
m
)
>
0
and
verbose
:
if
len
(
m
)
>
0
and
verbose
:
print
(
"missing keys:"
)
print
(
"missing keys:"
)
...
...
examples/images/diffusion/scripts/inpaint.py
View file @
fb8fae6f
...
@@ -4,7 +4,7 @@ from PIL import Image
...
@@ -4,7 +4,7 @@ from PIL import Image
from
tqdm
import
tqdm
from
tqdm
import
tqdm
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
ldm.models.diffusion.ddpm
import
LatentgDiffusion
from
main
import
instantiate_from_config
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.models.diffusion.ddim
import
DDIMSampler
...
@@ -57,7 +57,7 @@ if __name__ == "__main__":
...
@@ -57,7 +57,7 @@ if __name__ == "__main__":
print
(
f
"Found
{
len
(
masks
)
}
inputs."
)
print
(
f
"Found
{
len
(
masks
)
}
inputs."
)
config
=
OmegaConf
.
load
(
"models/ldm/inpainting_big/config.yaml"
)
config
=
OmegaConf
.
load
(
"models/ldm/inpainting_big/config.yaml"
)
model
=
LatentDiffusion
(
**
config
.
model
.
get
(
"params"
,
dict
())
)
model
=
instantiate_from_config
(
config
.
model
)
model
.
load_state_dict
(
torch
.
load
(
"models/ldm/inpainting_big/last.ckpt"
)[
"state_dict"
],
model
.
load_state_dict
(
torch
.
load
(
"models/ldm/inpainting_big/last.ckpt"
)[
"state_dict"
],
strict
=
False
)
strict
=
False
)
...
...
examples/images/diffusion/scripts/knn2img.py
View file @
fb8fae6f
...
@@ -13,10 +13,9 @@ import scann
...
@@ -13,10 +13,9 @@ import scann
import
time
import
time
from
multiprocessing
import
cpu_count
from
multiprocessing
import
cpu_count
from
ldm.util
import
parallel_data_prefetch
from
ldm.util
import
instantiate_from_config
,
parallel_data_prefetch
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.models.diffusion.plms
import
PLMSSampler
from
ldm.models.diffusion.plms
import
PLMSSampler
from
ldm.models.diffusion.ddpm
import
LatentDiffusion
from
ldm.modules.encoders.modules
import
FrozenClipImageEmbedder
,
FrozenCLIPTextEmbedder
from
ldm.modules.encoders.modules
import
FrozenClipImageEmbedder
,
FrozenCLIPTextEmbedder
DATABASES
=
[
DATABASES
=
[
...
@@ -45,7 +44,7 @@ def load_model_from_config(config, ckpt, verbose=False):
...
@@ -45,7 +44,7 @@ def load_model_from_config(config, ckpt, verbose=False):
if
"global_step"
in
pl_sd
:
if
"global_step"
in
pl_sd
:
print
(
f
"Global Step:
{
pl_sd
[
'global_step'
]
}
"
)
print
(
f
"Global Step:
{
pl_sd
[
'global_step'
]
}
"
)
sd
=
pl_sd
[
"state_dict"
]
sd
=
pl_sd
[
"state_dict"
]
model
=
LatentDiffusion
(
**
config
.
model
.
get
(
"params"
,
dict
())
)
model
=
instantiate_from_config
(
config
.
model
)
m
,
u
=
model
.
load_state_dict
(
sd
,
strict
=
False
)
m
,
u
=
model
.
load_state_dict
(
sd
,
strict
=
False
)
if
len
(
m
)
>
0
and
verbose
:
if
len
(
m
)
>
0
and
verbose
:
print
(
"missing keys:"
)
print
(
"missing keys:"
)
...
...
examples/images/diffusion/scripts/sample_diffusion.py
View file @
fb8fae6f
...
@@ -8,6 +8,7 @@ from omegaconf import OmegaConf
...
@@ -8,6 +8,7 @@ from omegaconf import OmegaConf
from
PIL
import
Image
from
PIL
import
Image
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.util
import
instantiate_from_config
rescale
=
lambda
x
:
(
x
+
1.
)
/
2.
rescale
=
lambda
x
:
(
x
+
1.
)
/
2.
...
@@ -217,7 +218,7 @@ def get_parser():
...
@@ -217,7 +218,7 @@ def get_parser():
def
load_model_from_config
(
config
,
sd
):
def
load_model_from_config
(
config
,
sd
):
model
=
LatentDiffusion
(
**
config
.
get
(
"params"
,
dict
())
)
model
=
instantiate_from_config
(
config
)
model
.
load_state_dict
(
sd
,
strict
=
False
)
model
.
load_state_dict
(
sd
,
strict
=
False
)
model
.
cuda
()
model
.
cuda
()
model
.
eval
()
model
.
eval
()
...
...
examples/images/diffusion/scripts/tests/test_checkpoint.py
View file @
fb8fae6f
...
@@ -9,7 +9,7 @@ from diffusers import StableDiffusionPipeline
...
@@ -9,7 +9,7 @@ from diffusers import StableDiffusionPipeline
import
torch
import
torch
from
ldm.util
import
instantiate_from_config
from
ldm.util
import
instantiate_from_config
from
main
import
get_parser
from
main
import
get_parser
from
ldm.modules.diffusionmodules.openaimodel
import
UNetModel
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
yaml_path
=
"../../train_colossalai.yaml"
yaml_path
=
"../../train_colossalai.yaml"
...
@@ -17,7 +17,7 @@ if __name__ == "__main__":
...
@@ -17,7 +17,7 @@ if __name__ == "__main__":
config
=
f
.
read
()
config
=
f
.
read
()
base_config
=
yaml
.
load
(
config
,
Loader
=
yaml
.
FullLoader
)
base_config
=
yaml
.
load
(
config
,
Loader
=
yaml
.
FullLoader
)
unet_config
=
base_config
[
'model'
][
'params'
][
'unet_config'
]
unet_config
=
base_config
[
'model'
][
'params'
][
'unet_config'
]
diffusion_model
=
UNetModel
(
**
unet_config
.
get
(
"params"
,
dict
())
).
to
(
"cuda:0"
)
diffusion_model
=
instantiate_from_config
(
unet_config
).
to
(
"cuda:0"
)
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"/data/scratch/diffuser/stable-diffusion-v1-4"
"/data/scratch/diffuser/stable-diffusion-v1-4"
...
...
examples/images/diffusion/scripts/txt2img.py
View file @
fb8fae6f
...
@@ -16,9 +16,9 @@ from torch import autocast
...
@@ -16,9 +16,9 @@ from torch import autocast
from
contextlib
import
nullcontext
from
contextlib
import
nullcontext
from
imwatermark
import
WatermarkEncoder
from
imwatermark
import
WatermarkEncoder
from
ldm.util
import
instantiate_from_config
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.models.diffusion.plms
import
PLMSSampler
from
ldm.models.diffusion.plms
import
PLMSSampler
from
ldm.models.diffusion.ddpm
import
LatentDiffusion
from
ldm.models.diffusion.dpm_solver
import
DPMSolverSampler
from
ldm.models.diffusion.dpm_solver
import
DPMSolverSampler
from
utils
import
replace_module
,
getModelSize
from
utils
import
replace_module
,
getModelSize
...
@@ -35,7 +35,7 @@ def load_model_from_config(config, ckpt, verbose=False):
...
@@ -35,7 +35,7 @@ def load_model_from_config(config, ckpt, verbose=False):
if
"global_step"
in
pl_sd
:
if
"global_step"
in
pl_sd
:
print
(
f
"Global Step:
{
pl_sd
[
'global_step'
]
}
"
)
print
(
f
"Global Step:
{
pl_sd
[
'global_step'
]
}
"
)
sd
=
pl_sd
[
"state_dict"
]
sd
=
pl_sd
[
"state_dict"
]
model
=
LatentDiffusion
(
**
config
.
model
.
get
(
"params"
,
dict
())
)
model
=
instantiate_from_config
(
config
.
model
)
m
,
u
=
model
.
load_state_dict
(
sd
,
strict
=
False
)
m
,
u
=
model
.
load_state_dict
(
sd
,
strict
=
False
)
if
len
(
m
)
>
0
and
verbose
:
if
len
(
m
)
>
0
and
verbose
:
print
(
"missing keys:"
)
print
(
"missing keys:"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment