Unverified Commit 6c4c6a04 authored by Fazzie-Maqianli, committed by GitHub

Merge pull request #2120 from Fazziekey/example/stablediffusion-v2

[example] support stable diffusion v2
parents 5efda697 cea4292a
# ColoDiffusion: Stable Diffusion with Colossal-AI
*[Colossal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower-cost solution for pretraining and
fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).*
@@ -6,6 +7,7 @@
We take advantage of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) to exploit multiple optimization strategies
, e.g. data parallelism, tensor parallelism, mixed precision & ZeRO, to scale the training to multiple GPUs.
## Stable Diffusion
[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) is a latent text-to-image diffusion
model.
Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
@@ -23,6 +25,7 @@ this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts.
</p>
## Requirements
A suitable [conda](https://conda.io/) environment named `ldm` can be created
and activated with:
@@ -34,14 +37,24 @@ conda activate ldm
You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
```
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install transformers==4.19.2 diffusers invisible-watermark
pip install -e .
```
### Install Lightning
```
git clone https://github.com/1SAA/lightning.git
cd lightning
git checkout strategy/colossalai
export PACKAGE_NAME=pytorch
pip install .
```
### Install [Colossal-AI v0.1.12](https://colossalai.org/download/) From Our Official Website
```
pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org
```
> The specified version is due to the interface incompatibility caused by the latest update of [Lightning](https://github.com/Lightning-AI/lightning), which will be fixed in the near future.
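To confirm that the matching build is active before launching training, a quick sanity check (a generic pip/Python check, not part of the example scripts):
```
python -c "import colossalai; print(colossalai.__version__)"
```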
@@ -49,6 +62,7 @@
## Download the pretrained model checkpoints
### stable-diffusion-v1-4
Our default model config uses the weights from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style)
```
git lfs install
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
```
### stable-diffusion-v1-5 from runwayml
If you want to use the latest [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) weights from runwayml
```
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
```
## Dataset
The dataset is from [LAION-5B](https://laion.ai/blog/laion-5b/), a subset of [LAION](https://laion.ai/);
you should change `data.file_path` in `configs/train_colossalai.yaml` to point at your local copy.
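For orientation, a minimal sketch of the corresponding block in the YAML, based on the field names used by the shipped configs; the concrete path is a placeholder you must replace:
```
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 16
    train:
      target: ldm.data.base.Txt2ImgIterableBaseDataset
      params:
        file_path: "/path/to/your/laion/subset/"   # <- change this
```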
## Training
We provide the script `train.sh` to run the training task, and two strategies in `configs`: `train_colossalai.yaml` and `train_ddp.yaml`
For example, you can run the training with the Colossal-AI strategy by
```
python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml
```
- you can change `--logdir` to decide where to save the log information and the last checkpoint
### Training config
You can change the training config in the yaml file; the options below live under `lightning.trainer` (see the sketch after this list)
- accelerator: accelerator type, default 'gpu'
@@ -88,27 +104,25 @@
- max_epochs: max training epochs
- precision: whether to use fp16 for training, default 16; you must use fp16 if you want to apply the Colossal-AI strategy
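A minimal sketch of where these options live in the YAML, matching the shipped configs (values are illustrative):
```
lightning:
  trainer:
    accelerator: 'gpu'
    devices: 2
    max_epochs: 2
    precision: 16   # fp16 is required for the Colossal-AI strategy
```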
## Finetune Example
### Training on Teyvat Datasets
We provide a finetuning example on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, whose captions were generated with BLIP.
You can run it with the config `configs/Teyvat/train_colossalai_teyvat.yaml`
```
python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml
```
## Inference
You can get your trained `last.ckpt` and the training `config.yaml` in your `--logdir`, and run inference by
```
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms \
    --outdir ./output \
    --config /path/to/logdir/configs/project.yaml \
    --ckpt /path/to/logdir/checkpoints/last.ckpt
```
```commandline
usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA]
                  [--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS] [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT]
@@ -144,7 +158,6 @@ optional arguments:
                        evaluate at this precision
```
## Comments
- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
...
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
model:
base_learning_rate: 5.0e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: hybrid
scale_factor: 0.18215
monitor: val/loss_simple_ema
finetune_keys: null
use_ema: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
data:
target: ldm.data.laion.WebDataModuleFromConfig
params:
tar_base: null # for concat as in LAION-A
p_unsafe_threshold: 0.1
filter_word_list: "data/filters.yaml"
max_pwatermark: 0.45
batch_size: 8
num_workers: 6
multinode: True
min_size: 512
train:
shards:
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
shuffle: 10000
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.RandomCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
# NOTE use enough shards to avoid empty validation loops in workers
validation:
shards:
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
shuffle: 0
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.CenterCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
lightning:
find_unused_parameters: True
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 10000
image_logger:
target: main.ImageLogger
params:
enable_autocast: False
disabled: False
batch_frequency: 1000
max_images: 4
increase_log_steps: False
log_first_step: False
log_images_kwargs:
use_ema_scope: False
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 4
unconditional_guidance_scale: 5.0
unconditional_guidance_label: [""]
ddim_steps: 50 # todo check these out for depth2img,
ddim_eta: 0.0 # todo check these out for depth2img,
trainer:
benchmark: True
val_check_interval: 5000000
num_sanity_val_steps: 0
accumulate_grad_batches: 1
model:
base_learning_rate: 5.0e-07
target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: hybrid
scale_factor: 0.18215
monitor: val/loss_simple_ema
finetune_keys: null
use_ema: False
depth_stage_config:
target: ldm.modules.midas.api.MiDaSInference
params:
model_type: "dpt_hybrid"
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 5
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
params:
parameterization: "v"
low_scale_key: "lr"
linear_start: 0.0001
linear_end: 0.02
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 128
channels: 4
cond_stage_trainable: false
conditioning_key: "hybrid-adm"
monitor: val/loss_simple_ema
scale_factor: 0.08333
use_ema: False
low_scale_config:
target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
params:
noise_schedule_config: # image space
linear_start: 0.0001
linear_end: 0.02
max_noise_level: 350
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
image_size: 128
in_channels: 7
out_channels: 4
model_channels: 256
attention_resolutions: [ 2,4,8]
num_res_blocks: 2
channel_mult: [ 1, 2, 2, 4]
disable_self_attentions: [True, True, True, False]
disable_middle_self_attn: False
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
use_linear_in_transformer: True
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
ddconfig:
# attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though)
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
# Dataset Card for Teyvat BLIP captions
Dataset used to train [Teyvat characters text to image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion).
BLIP generated captions for character images from the [genshin-impact fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters) and the [biligame wiki for genshin impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2).
For each row the dataset contains `image` and `text` keys. `image` is a varying-size PIL png, and `text` is the accompanying text caption. Only a train split is provided.
The `text` includes the tags `Teyvat`, `Name`, `Element`, `Weapon`, `Region`, `Model type`, and `Description`; the `Description` is captioned with the [pre-trained BLIP model](https://github.com/salesforce/BLIP).
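A minimal sketch of loading it with the Hugging Face `datasets` library (assuming the single `train` split described above):
```
from datasets import load_dataset

# downloads the dataset from the Hugging Face Hub
ds = load_dataset("Fazzie/Teyvat", split="train")
sample = ds[0]
print(sample["text"])    # tagged caption, e.g. "Teyvat, Name:Ganyu, ..."
sample["image"].show()   # varying-size PIL image
```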
## Examples
<img src = "https://huggingface.co/datasets/Fazzie/Teyvat/resolve/main/data/Ganyu_001.png" title = "Ganyu_001.png" style="max-width: 20%;" >
> Teyvat, Name:Ganyu, Element:Cryo, Weapon:Bow, Region:Liyue, Model type:Medium Female, Description:an anime character with blue hair and blue eyes
<img src = "https://huggingface.co/datasets/Fazzie/Teyvat/resolve/main/data/Ganyu_002.png" title = "Ganyu_002.png" style="max-width: 20%;" >
> Teyvat, Name:Ganyu, Element:Cryo, Weapon:Bow, Region:Liyue, Model type:Medium Female, Description:an anime character with blue hair and blue eyes
<img src = "https://huggingface.co/datasets/Fazzie/Teyvat/resolve/main/data/Keqing_003.png" title = "Keqing_003.png" style="max-width: 20%;" >
> Teyvat, Name:Keqing, Element:Electro, Weapon:Sword, Region:Liyue, Model type:Medium Female, Description:a anime girl with long white hair and blue eyes
<img src = "https://huggingface.co/datasets/Fazzie/Teyvat/resolve/main/data/Keqing_004.png" title = "Keqing_004.png" style="max-width: 20%;" >
> Teyvat, Name:Keqing, Element:Electro, Weapon:Sword, Region:Liyue, Model type:Medium Female, Description:an anime character wearing a purple dress and cat ears
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
@@ -11,11 +12,11 @@ model:
    cond_stage_key: txt
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
@@ -26,31 +27,33 @@ model:
      f_max: [ 1.e-4 ]
      f_min: [ 1.e-10 ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
@@ -69,9 +72,10 @@ model:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

data:
  target: main.DataModuleFromConfig
@@ -86,37 +90,37 @@ data:
        - target: torchvision.transforms.Resize
          params:
            size: 512
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
        - target: torchvision.transforms.RandomHorizontalFlip

lightning:
  trainer:
    accelerator: 'gpu'
    devices: 2
    log_gpu_memory: all
    max_epochs: 2
    precision: 16
    auto_select_gpus: False
    strategy:
      target: strategies.ColossalAIStrategy
      params:
        use_chunk: True
        enable_distributed_storage: True
        placement_policy: auto
        force_outputs_fp32: true
    log_every_n_steps: 2
    logger: True
    default_root_dir: "/tmp/diff_log/"
    # profiler: pytorch
    logger_config:
      wandb:
        target: loggers.WandbLogger
        params:
          name: nowname
          save_dir: "/tmp/diff_log/"
          offline: opt.debug
          id: nowname
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: txt
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
@@ -26,31 +27,33 @@ model:
      f_max: [ 1.e-4 ]
      f_min: [ 1.e-10 ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
@@ -69,9 +72,10 @@ model:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

data:
  target: main.DataModuleFromConfig
@@ -87,30 +91,30 @@ data:

lightning:
  trainer:
    accelerator: 'gpu'
    devices: 1
    log_gpu_memory: all
    max_epochs: 2
    precision: 16
    auto_select_gpus: False
    strategy:
      target: strategies.ColossalAIStrategy
      params:
        use_chunk: True
        enable_distributed_storage: True
        placement_policy: auto
        force_outputs_fp32: true
    log_every_n_steps: 2
    logger: True
    default_root_dir: "/tmp/diff_log/"
    # profiler: pytorch
    logger_config:
      wandb:
        target: loggers.WandbLogger
        params:
          name: nowname
          save_dir: "/tmp/diff_log/"
          offline: opt.debug
          id: nowname
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
@@ -11,11 +12,11 @@ model:
    cond_stage_key: txt
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
@@ -26,31 +27,33 @@ model:
      f_max: [ 1.e-4 ]
      f_min: [ 1.e-10 ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
@@ -69,9 +72,10 @@ model:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

data:
  target: main.DataModuleFromConfig
@@ -94,30 +98,30 @@ data:

lightning:
  trainer:
    accelerator: 'gpu'
    devices: 1
    log_gpu_memory: all
    max_epochs: 2
    precision: 16
    auto_select_gpus: False
    strategy:
      target: strategies.ColossalAIStrategy
      params:
        use_chunk: True
        enable_distributed_storage: True
        placement_policy: auto
        force_outputs_fp32: true
    log_every_n_steps: 2
    logger: True
    default_root_dir: "/tmp/diff_log/"
    # profiler: pytorch
    logger_config:
      wandb:
        target: loggers.WandbLogger
        params:
          name: nowname
          save_dir: "/tmp/diff_log/"
          offline: opt.debug
          id: nowname
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: txt
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1.e-4 ]
        f_min: [ 1.e-10 ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
@@ -69,32 +72,39 @@ model:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 16
    num_workers: 4
    train:
      target: ldm.data.teyvat.hf_dataset
      params:
        path: Fazzie/Teyvat
        image_transforms:
          - target: torchvision.transforms.Resize
            params:
              size: 512
          - target: torchvision.transforms.RandomCrop
            params:
              size: 512
          - target: torchvision.transforms.RandomHorizontalFlip

lightning:
  trainer:
    accelerator: 'gpu'
    devices: 2
    log_gpu_memory: all
    max_epochs: 2
    precision: 16
    auto_select_gpus: False
    strategy:
      target: strategies.DDPStrategy
      params:
        find_unused_parameters: False
    log_every_n_steps: 2
@@ -105,9 +115,9 @@ lightning:
    logger_config:
      wandb:
        target: loggers.WandbLogger
        params:
          name: nowname
          save_dir: "/data2/tmp/diff_log/"
          offline: opt.debug
          id: nowname
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: txt
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config
    check_nan_inf: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1.e-4 ]
        f_min: [ 1.e-10 ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
@@ -70,9 +72,10 @@ model:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

data:
  target: main.DataModuleFromConfig
@@ -88,34 +91,30 @@ data:

lightning:
  trainer:
    accelerator: 'gpu'
    devices: 1
    log_gpu_memory: all
    max_epochs: 2
    precision: 16
    auto_select_gpus: False
    strategy:
      target: strategies.ColossalAIStrategy
      params:
        use_chunk: True
        enable_distributed_storage: True
        placement_policy: auto
        force_outputs_fp32: true
    log_every_n_steps: 2
    logger: True
    default_root_dir: "/tmp/diff_log/"
    # profiler: pytorch
    logger_config:
      wandb:
        target: loggers.WandbLogger
        params:
          name: nowname
          save_dir: "/tmp/diff_log/"
          offline: opt.debug
          id: nowname
@@ -6,28 +6,25 @@ dependencies:
  - python=3.9.12
  - pip=20.3
  - cudatoolkit=11.3
  - pytorch=1.12.1
  - torchvision=0.13.1
  - numpy=1.23.1
  - pip:
    - albumentations==1.3.0
    - opencv-python==4.6.0.66
    - imageio==2.9.0
    - imageio-ffmpeg==0.4.2
    - lightning==1.8.1
    - omegaconf==2.1.1
    - test-tube>=0.7.5
    - streamlit==1.12.1
    - einops==0.3.0
    - transformers==4.19.2
    - webdataset==0.2.5
    - kornia==0.6
    - open_clip_torch==2.0.2
    - invisible-watermark>=0.1.5
    - streamlit-drawable-canvas==0.8.0
    - torchmetrics==0.7.0
    - prefetch_generator
    - datasets
    - -e .
@@ -3,10 +3,8 @@
import torch
import numpy as np
from tqdm import tqdm

from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor


class DDIMSampler(object):
@@ -74,15 +72,24 @@ class DDIMSampler(object):
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,  # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
               dynamic_threshold=None,
               ucg_schedule=None,
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
@@ -107,6 +114,8 @@ class DDIMSampler(object):
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    ucg_schedule=ucg_schedule
                                                    )
        return samples, intermediates
@@ -116,7 +125,8 @@ class DDIMSampler(object):
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                      ucg_schedule=None):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
@@ -145,12 +155,18 @@ class DDIMSampler(object):
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            if ucg_schedule is not None:
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)
@@ -164,20 +180,44 @@ class DDIMSampler(object):
    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [torch.cat([
                            unconditional_conditioning[k][i],
                            c[k][i]]) for i in range(len(c[k]))]
                    else:
                        c_in[k] = torch.cat([
                            unconditional_conditioning[k],
                            c[k]])
            elif isinstance(c, list):
                c_in = list()
                assert isinstance(unconditional_conditioning, list)
                for i in range(len(c)):
                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
@@ -191,9 +231,17 @@ class DDIMSampler(object):
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
@@ -202,6 +250,53 @@ class DDIMSampler(object):
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
@@ -220,7 +315,7 @@ class DDIMSampler(object):
    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False, callback=None):
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]
@@ -237,4 +332,5 @@ class DDIMSampler(object):
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
            if callback: callback(i)
        return x_dec
from .sampler import DPMSolverSampler
"""SAMPLING ONLY."""
import torch
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
MODEL_TYPES = {
"eps": "noise",
"v": "v"
}
class DPMSolverSampler(object):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
device = self.model.betas.device
if x_T is None:
img = torch.randn(size, device=device)
else:
img = x_T
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
ns,
model_type=MODEL_TYPES[self.model.parameterization],
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
return x.to(device), None
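A minimal usage sketch for the sampler above (hypothetical: `model` is assumed to be an already-loaded `LatentDiffusion` instance on GPU, and the 4x64x64 latent shape follows the configs in this PR):
```
import torch

sampler = DPMSolverSampler(model)  # `model`: a loaded LatentDiffusion instance
with torch.no_grad():
    c = model.get_learned_conditioning(["a photograph of an astronaut riding a horse"])
    uc = model.get_learned_conditioning([""])  # empty prompt for classifier-free guidance
    samples, _ = sampler.sample(S=25, batch_size=1, shape=(4, 64, 64),
                                conditioning=c,
                                unconditional_guidance_scale=5.0,
                                unconditional_conditioning=uc)
    imgs = model.decode_first_stage(samples)  # latents -> images in [-1, 1]
```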
@@ -6,6 +6,7 @@ from tqdm import tqdm
from functools import partial

from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding


class PLMSSampler(object):
@@ -77,6 +78,7 @@ class PLMSSampler(object):
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
               dynamic_threshold=None,
               **kwargs
               ):
        if conditioning is not None:
@@ -108,6 +110,7 @@ class PLMSSampler(object):
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    )
        return samples, intermediates
@@ -117,7 +120,8 @@ class PLMSSampler(object):
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
@@ -155,7 +159,8 @@ class PLMSSampler(object):
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
@@ -172,7 +177,8 @@ class PLMSSampler(object):
    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
                      dynamic_threshold=None):
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
@@ -207,6 +213,8 @@ class PLMSSampler(object):
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            if dynamic_threshold is not None:
                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
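For context, a sketch of what the imported `norm_thresholding` does, paraphrased from upstream `ldm/models/diffusion/sampling_util.py` (treat the exact implementation as an assumption): it rescales `pred_x0` whenever its per-sample RMS norm exceeds the threshold:
```
import torch

def norm_thresholding(x0, value):
    # per-sample RMS norm, clamped from below at `value`
    s = x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value)
    s = s.reshape(-1, *([1] * (x0.ndim - 1)))  # broadcast back to x0's shape
    return x0 * (value / s)
```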
...