OpenDAS / ColossalAI · Commits · b5dbb461

Commit b5dbb461 (unverified), authored Nov 20, 2022 by Fazzie-Maqianli, committed by GitHub on Nov 20, 2022

[example] add diffusion inference (#1986)

Parent: a01278e8

Showing 7 changed files with 344 additions and 46 deletions (+344 / -46)
examples/images/diffusion/README.md                                +46   -2
examples/images/diffusion/configs/train_colossalai_teyvat.yaml    +122   -0
examples/images/diffusion/ldm/data/teyvat.py                      +152   -0
examples/images/diffusion/ldm/models/diffusion/ddpm.py             +18  -44
examples/images/diffusion/scripts/download_first_stages.sh          +0   -0
examples/images/diffusion/scripts/download_models.sh                +0   -0
examples/images/diffusion/scripts/txt2img.sh                        +6   -0
examples/images/diffusion/README.md

@@ -96,11 +96,55 @@ We provide the finetuning example on CIFAR10 dataset

You can run the CIFAR10 finetuning with the config `train_colossalai_cifar10.yaml`:

```
python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai_cifar10.yaml
```
## Inference

After training, you can find your `last.ckpt` under `checkpoints/` and the training config `project.yaml` under `configs/` in your `--logdir`, and run inference with:

```
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms \
    --outdir ./output \
    --config /path/to/logdir/configs/project.yaml \
    --ckpt /path/to/logdir/checkpoints/last.ckpt
```
```commandline
usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA]
                  [--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS] [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT]
                  [--seed SEED] [--precision {full,autocast}]

optional arguments:
  -h, --help            show this help message and exit
  --prompt [PROMPT]     the prompt to render
  --outdir [OUTDIR]     dir to write results to
  --skip_grid           do not save a grid, only individual samples. Helpful when evaluating lots of samples
  --skip_save           do not save individual samples. For speed measurements.
  --ddim_steps DDIM_STEPS
                        number of ddim sampling steps
  --plms                use plms sampling
  --laion400m           uses the LAION400M model
  --fixed_code          if enabled, uses the same starting code across samples
  --ddim_eta DDIM_ETA   ddim eta (eta=0.0 corresponds to deterministic sampling)
  --n_iter N_ITER       sample this often
  --H H                 image height, in pixel space
  --W W                 image width, in pixel space
  --C C                 latent channels
  --f F                 downsampling factor
  --n_samples N_SAMPLES
                        how many samples to produce for each given prompt. A.k.a. batch size
  --n_rows N_ROWS       rows in the grid (default: n_samples)
  --scale SCALE         unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
  --from-file FROM_FILE
                        if specified, load prompts from this file
  --config CONFIG       path to config which constructs model
  --ckpt CKPT           path to checkpoint of model
  --seed SEED           the seed (for reproducible sampling)
  --precision {full,autocast}
                        evaluate at this precision
```
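The `--scale` flag is classifier-free guidance. A minimal standalone sketch of the formula quoted in the help text above (an illustration, not the repo's sampler code):

```python
import torch

def guided_eps(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, scale: float) -> torch.Tensor:
    # eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
    return eps_uncond + scale * (eps_cond - eps_uncond)
```

With `scale=1.0` this reduces to the conditional prediction; larger values weight the prompt more heavily at the cost of sample diversity.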
## Comments

- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
examples/images/diffusion/configs/train_colossalai_teyvat.yaml (new file, mode 100644)

```yaml
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: txt
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1.e-4 ]
        f_min: [ 1.e-10 ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: False
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
      params:
        use_fp16: True

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 16
    num_workers: 4
    train:
      target: ldm.data.teyvat.hf_dataset
      params:
        path: Fazzie/Teyvat
        image_transforms:
          - target: torchvision.transforms.Resize
            params:
              size: 512
          # - target: torchvision.transforms.RandomCrop
          #   params:
          #     size: 256
          # - target: torchvision.transforms.RandomHorizontalFlip

lightning:
  trainer:
    accelerator: 'gpu'
    devices: 2
    log_gpu_memory: all
    max_epochs: 10
    precision: 16
    auto_select_gpus: False
    strategy:
      target: lightning.pytorch.strategies.ColossalAIStrategy
      params:
        use_chunk: False
        enable_distributed_storage: True
        placement_policy: cuda
        force_outputs_fp32: False
    log_every_n_steps: 2
    logger: True
    default_root_dir: "/tmp/diff_log/"
    profiler: pytorch

  logger_config:
    wandb:
      target: lightning.pytorch.loggers.WandbLogger
      params:
        name: nowname
        save_dir: "/tmp/diff_log/"
        offline: opt.debug
        id: nowname
```
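Each `target` key in this config names a class that the training entry point presumably builds via `ldm.util.instantiate_from_config` (the same helper the new `teyvat.py` below imports). A minimal sketch of that pattern, assuming OmegaConf; the exact wiring inside main.py may differ:

```python
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# Load the YAML and build the objects named by each `target` key.
config = OmegaConf.load("configs/train_colossalai_teyvat.yaml")
model = instantiate_from_config(config.model)  # ldm.models.diffusion.ddpm.LatentDiffusion
data = instantiate_from_config(config.data)    # main.DataModuleFromConfig
```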
examples/images/diffusion/ldm/data/teyvat.py
0 → 100644
View file @
b5dbb461
from
typing
import
Dict
import
numpy
as
np
from
omegaconf
import
DictConfig
,
ListConfig
import
torch
from
torch.utils.data
import
Dataset
from
pathlib
import
Path
import
json
from
PIL
import
Image
from
torchvision
import
transforms
from
einops
import
rearrange
from
ldm.util
import
instantiate_from_config
from
datasets
import
load_dataset
def
make_multi_folder_data
(
paths
,
caption_files
=
None
,
**
kwargs
):
"""Make a concat dataset from multiple folders
Don't suport captions yet
If paths is a list, that's ok, if it's a Dict interpret it as:
k=folder v=n_times to repeat that
"""
list_of_paths
=
[]
if
isinstance
(
paths
,
(
Dict
,
DictConfig
)):
assert
caption_files
is
None
,
\
"Caption files not yet supported for repeats"
for
folder_path
,
repeats
in
paths
.
items
():
list_of_paths
.
extend
([
folder_path
]
*
repeats
)
paths
=
list_of_paths
if
caption_files
is
not
None
:
datasets
=
[
FolderData
(
p
,
caption_file
=
c
,
**
kwargs
)
for
(
p
,
c
)
in
zip
(
paths
,
caption_files
)]
else
:
datasets
=
[
FolderData
(
p
,
**
kwargs
)
for
p
in
paths
]
return
torch
.
utils
.
data
.
ConcatDataset
(
datasets
)
class
FolderData
(
Dataset
):
def
__init__
(
self
,
root_dir
,
caption_file
=
None
,
image_transforms
=
[],
ext
=
"jpg"
,
default_caption
=
""
,
postprocess
=
None
,
return_paths
=
False
,
)
->
None
:
"""Create a dataset from a folder of images.
If you pass in a root directory it will be searched for images
ending in ext (ext can be a list)
"""
self
.
root_dir
=
Path
(
root_dir
)
self
.
default_caption
=
default_caption
self
.
return_paths
=
return_paths
if
isinstance
(
postprocess
,
DictConfig
):
postprocess
=
instantiate_from_config
(
postprocess
)
self
.
postprocess
=
postprocess
if
caption_file
is
not
None
:
with
open
(
caption_file
,
"rt"
)
as
f
:
ext
=
Path
(
caption_file
).
suffix
.
lower
()
if
ext
==
".json"
:
captions
=
json
.
load
(
f
)
elif
ext
==
".jsonl"
:
lines
=
f
.
readlines
()
lines
=
[
json
.
loads
(
x
)
for
x
in
lines
]
captions
=
{
x
[
"file_name"
]:
x
[
"text"
].
strip
(
"
\n
"
)
for
x
in
lines
}
else
:
raise
ValueError
(
f
"Unrecognised format:
{
ext
}
"
)
self
.
captions
=
captions
else
:
self
.
captions
=
None
if
not
isinstance
(
ext
,
(
tuple
,
list
,
ListConfig
)):
ext
=
[
ext
]
# Only used if there is no caption file
self
.
paths
=
[]
for
e
in
ext
:
self
.
paths
.
extend
(
list
(
self
.
root_dir
.
rglob
(
f
"*.
{
e
}
"
)))
if
isinstance
(
image_transforms
,
ListConfig
):
image_transforms
=
[
instantiate_from_config
(
tt
)
for
tt
in
image_transforms
]
image_transforms
.
extend
([
transforms
.
ToTensor
(),
transforms
.
Lambda
(
lambda
x
:
rearrange
(
x
*
2.
-
1.
,
'c h w -> h w c'
))])
image_transforms
=
transforms
.
Compose
(
image_transforms
)
self
.
tform
=
image_transforms
def
__len__
(
self
):
if
self
.
captions
is
not
None
:
return
len
(
self
.
captions
.
keys
())
else
:
return
len
(
self
.
paths
)
def
__getitem__
(
self
,
index
):
data
=
{}
if
self
.
captions
is
not
None
:
chosen
=
list
(
self
.
captions
.
keys
())[
index
]
caption
=
self
.
captions
.
get
(
chosen
,
None
)
if
caption
is
None
:
caption
=
self
.
default_caption
filename
=
self
.
root_dir
/
chosen
else
:
filename
=
self
.
paths
[
index
]
if
self
.
return_paths
:
data
[
"path"
]
=
str
(
filename
)
im
=
Image
.
open
(
filename
)
im
=
self
.
process_im
(
im
)
data
[
"image"
]
=
im
if
self
.
captions
is
not
None
:
data
[
"txt"
]
=
caption
else
:
data
[
"txt"
]
=
self
.
default_caption
if
self
.
postprocess
is
not
None
:
data
=
self
.
postprocess
(
data
)
return
data
def
process_im
(
self
,
im
):
im
=
im
.
convert
(
"RGB"
)
return
self
.
tform
(
im
)
def
hf_dataset
(
path
=
"Fazzie/Teyvat"
,
image_transforms
=
[],
image_column
=
"image"
,
text_column
=
"text"
,
image_key
=
'image'
,
caption_key
=
'txt'
,
):
"""Make huggingface dataset with appropriate list of transforms applied
"""
ds
=
load_dataset
(
path
,
name
=
"train"
)
ds
=
ds
[
"train"
]
image_transforms
=
[
instantiate_from_config
(
tt
)
for
tt
in
image_transforms
]
image_transforms
.
extend
([
transforms
.
Resize
((
256
,
256
)),
transforms
.
ToTensor
(),
transforms
.
Lambda
(
lambda
x
:
rearrange
(
x
*
2.
-
1.
,
'c h w -> h w c'
))]
)
tform
=
transforms
.
Compose
(
image_transforms
)
assert
image_column
in
ds
.
column_names
,
f
"Didn't find column
{
image_column
}
in
{
ds
.
column_names
}
"
assert
text_column
in
ds
.
column_names
,
f
"Didn't find column
{
text_column
}
in
{
ds
.
column_names
}
"
def
pre_process
(
examples
):
processed
=
{}
processed
[
image_key
]
=
[
tform
(
im
)
for
im
in
examples
[
image_column
]]
processed
[
caption_key
]
=
examples
[
text_column
]
return
processed
ds
.
set_transform
(
pre_process
)
return
ds
\ No newline at end of file
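A quick, hypothetical smoke test for `hf_dataset` (it downloads the Fazzie/Teyvat dataset from the Hugging Face Hub on first use):

```python
from ldm.data.teyvat import hf_dataset

ds = hf_dataset(path="Fazzie/Teyvat")
sample = ds[0]
print(sample["image"].shape)  # (h, w, c) tensor scaled to [-1, 1]
print(sample["txt"])          # the paired caption
```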
examples/images/diffusion/ldm/models/diffusion/ddpm.py

```diff
@@ -99,12 +99,12 @@ class DDPM(pl.LightningModule):
         self.use_positional_encodings = use_positional_encodings
         self.unet_config = unet_config
         self.conditioning_key = conditioning_key
-        # self.model = DiffusionWrapper(unet_config, conditioning_key)
-        # count_params(self.model, verbose=True)
+        self.model = DiffusionWrapper(unet_config, conditioning_key)
+        count_params(self.model, verbose=True)
         self.use_ema = use_ema
-        # if self.use_ema:
-        #     self.model_ema = LitEma(self.model)
-        #     print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+        if self.use_ema:
+            self.model_ema = LitEma(self.model)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
         self.use_scheduler = scheduler_config is not None
         if self.use_scheduler:
```
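This first hunk re-enables the model construction and the EMA branch that were previously commented out. As a minimal sketch of the exponential-moving-average idea behind `LitEma` (an illustration, not the repo's implementation):

```python
import torch

@torch.no_grad()
def ema_update(shadow_params, model_params, decay: float = 0.9999):
    """One EMA step: shadow <- decay * shadow + (1 - decay) * param."""
    for shadow, param in zip(shadow_params, model_params):
        shadow.mul_(decay).add_(param, alpha=1 - decay)
```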
```diff
@@ -125,20 +125,20 @@ class DDPM(pl.LightningModule):
         self.linear_start = linear_start
         self.linear_end = linear_end
         self.cosine_s = cosine_s
-        # if ckpt_path is not None:
-        #     self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
-
-        # self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
-        #                        linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+
+        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

         self.loss_type = loss_type
         self.learn_logvar = learn_logvar
         self.logvar_init = logvar_init
-        # self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
-        # if self.learn_logvar:
-        #     self.logvar = nn.Parameter(self.logvar, requires_grad=True)
-        #     self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+        if self.learn_logvar:
+            self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+            self.logvar = nn.Parameter(self.logvar, requires_grad=True)

         self.use_fp16 = use_fp16
         if use_fp16:
@@ -312,14 +312,6 @@ class DDPM(pl.LightningModule):
     def get_loss(self, pred, target, mean=True):
-        if pred.isnan().any():
-            print("Warning: Prediction has nan values")
-            lr = self.optimizers().param_groups[0]['lr']
-            # self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
-            print(f"lr: {lr}")
-        if pred.isinf().any():
-            print("Warning: Prediction has inf values")
         if self.use_fp16:
             target = target.half()
@@ -334,15 +326,6 @@ class DDPM(pl.LightningModule):
                 loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
         else:
             raise NotImplementedError("unknown loss type '{loss_type}'")
-        if loss.isnan().any():
-            print("Warning: loss has nan values")
-            print("loss: ", loss[0][0][0])
-            raise ValueError("loss has nan values")
-        if loss.isinf().any():
-            print("Warning: loss has inf values")
-            print("loss: ", loss)
-            raise ValueError("loss has inf values")
         return loss
@@ -382,11 +365,7 @@ class DDPM(pl.LightningModule):
         return self.p_losses(x, t, *args, **kwargs)

     def get_input(self, batch, k):
-        # print("+" * 30)
-        # print(batch['jpg'].shape)
-        # print(len(batch['txt']))
-        # print(k)
-        # print("=" * 30)
         if not isinstance(batch, torch.Tensor):
             x = batch[k]
         else:
@@ -534,8 +513,8 @@ class LatentDiffusion(DDPM):
         else:
             self.cond_stage_config["params"].update({"use_fp16": False})
             rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"]))
-        # self.instantiate_first_stage(first_stage_config)
-        # self.instantiate_cond_stage(cond_stage_config)
+        self.instantiate_first_stage(first_stage_config)
+        self.instantiate_cond_stage(cond_stage_config)
         self.cond_stage_forward = cond_stage_forward
         self.clip_denoised = False
         self.bbox_tokenizer = None
@@ -561,16 +540,11 @@ class LatentDiffusion(DDPM):
         self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,))
         if self.learn_logvar:
             self.logvar = nn.Parameter(self.logvar, requires_grad=True)
-            # self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+            self.logvar = nn.Parameter(self.logvar, requires_grad=True)
         if self.ckpt_path is not None:
             self.init_from_ckpt(self.ckpt_path, self.ignore_keys)
             self.restarted_from_ckpt = True
-        # TODO()
-        # for p in self.model.modules():
-        #     if not p.parameters().data.is_contiguous:
-        #         p.data = p.data.contiguous()
         self.instantiate_first_stage(self.first_stage_config)
         self.instantiate_cond_stage(self.cond_stage_config)
```
examples/images/diffusion/scripts/download_first_stages.sh
File mode changed from 100644 to 100755

examples/images/diffusion/scripts/download_models.sh
File mode changed from 100644 to 100755
examples/images/diffusion/scripts/txt2img.sh (new file, mode 100755)

```bash
python scripts/txt2img.py --prompt "Teyvat, Name:Layla, Element: Cryo, Weapon:Sword, Region:Sumeru, Model type:Medium Female, Description:a woman in a blue outfit holding a sword" --plms \
    --outdir ./output \
    --config /home/lcmql/data2/Genshin/2022-11-18T16-38-46_train_colossalai_teyvattest/configs/2022-11-18T16-38-46-project.yaml \
    --ckpt /home/lcmql/data2/Genshin/2022-11-18T16-38-46_train_colossalai_teyvattest/checkpoints/last.ckpt \
    --n_samples 4
```
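The prompt follows the caption format of the Fazzie/Teyvat dataset. A hypothetical helper for composing such prompts (the field layout is copied verbatim from the command above):

```python
def teyvat_prompt(name, element, weapon, region, model_type, description):
    # Field layout mirrors the caption style used in the script above.
    return (f"Teyvat, Name:{name}, Element: {element}, Weapon:{weapon}, "
            f"Region:{region}, Model type:{model_type}, Description:{description}")

print(teyvat_prompt("Layla", "Cryo", "Sword", "Sumeru", "Medium Female",
                    "a woman in a blue outfit holding a sword"))
```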