Commit 5a92374a authored by suily

test commit
# SAT CogView3 & CogView-3-Plus
[Read this in Chinese](./README_zh.md)
This folder contains inference code that uses the [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, as well as fine-tuning code for the SAT weights.
This is the framework the team used during model training. It has few comments and requires careful study.
## Step-by-step guide to running the model
### 1. Environment setup
Ensure you have installed the dependencies required by this folder:
```shell
pip install -r requirements.txt
```
### 2. Download model weights
The following links are for different model weights:
### CogView-3-Plus-3B
+ transformer: https://cloud.tsinghua.edu.cn/d/f913eabd3f3b4e28857c
+ vae: https://cloud.tsinghua.edu.cn/d/af4cc066ce8a4cf2ab79
### CogView-3-Base-3B
+ transformer:
  + cogview3-base: https://cloud.tsinghua.edu.cn/d/242b66daf4424fa99bf0
  + cogview3-base-distill-4step: https://cloud.tsinghua.edu.cn/d/d10032a94db647f5aa0e
  + cogview3-base-distill-8step: https://cloud.tsinghua.edu.cn/d/1598d4fe4ebf4afcb6ae

  **These three versions are interchangeable. Choose the one that suits your needs and run it with the corresponding configuration file.**
+ vae: https://cloud.tsinghua.edu.cn/d/c8b9497fc5124d71818a/
### CogView-3-Base-3B-Relay
+ transformer:
  + cogview3-relay: https://cloud.tsinghua.edu.cn/d/134951acced949c1a9e1/
  + cogview3-relay-distill-2step: https://cloud.tsinghua.edu.cn/d/6a902976fcb94ac48402
  + cogview3-relay-distill-1step: https://cloud.tsinghua.edu.cn/d/4d50ec092c64418f8418/

  **These three versions are interchangeable. Choose the one that suits your needs and run it with the corresponding configuration file.**
+ vae: Same as CogView-3-Base-3B
Next, arrange the model files into the following format:
```
.cogview3-plus-3b
├── transformer
│   ├── 1
│   │   └── mp_rank_00_model_states.pt
│   └── latest
└── vae
    └── imagekl_ch16.pt
```
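If you want to verify the layout before running anything, a quick check like the following can help. This is only a sketch; `root` and the file list are assumptions based on the tree above, so adjust them to wherever you placed the weights.
```python
# Minimal layout check for the SAT checkpoint folder shown above.
# `root` is a placeholder path; point it at your local weights directory.
import os

root = "cogview3-plus-3b"
expected = [
    "transformer/latest",
    "transformer/1/mp_rank_00_model_states.pt",
    "vae/imagekl_ch16.pt",
]
for rel in expected:
    path = os.path.join(root, rel)
    print(("OK      " if os.path.exists(path) else "MISSING ") + path)
```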
Clone the T5 model. The T5 model is not trained or fine-tuned, but it is required. You can download it separately, but it must be in `safetensors` format, not `bin` format (otherwise an error may occur).
Since we have already uploaded the T5 model in `safetensors` format for `CogVideoX`, a simple approach is to clone it from the `CogVideoX-2B` repository and move the relevant files into the corresponding folder.
```shell
git clone https://huggingface.co/THUDM/CogVideoX-2b.git
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
```
With this setup, you will have the T5 model in `safetensors` format, ensuring it loads without errors during DeepSpeed fine-tuning.
```
├── added_tokens.json
├── config.json
├── model-00001-of-00002.safetensors
├── model-00002-of-00002.safetensors
├── model.safetensors.index.json
├── special_tokens_map.json
├── spiece.model
└── tokenizer_config.json
0 directories, 8 files
```
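To confirm that the assembled folder loads from `safetensors` correctly, a quick check along these lines may help. This is only a sketch that assumes the local `t5-v1_1-xxl` folder created above; note that the full xxl encoder is large, so loading it requires substantial memory.
```python
# Sanity check: load the T5 tokenizer and encoder from the safetensors files assembled above.
# Assumes the local folder "t5-v1_1-xxl" from the previous step.
from transformers import T5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-v1_1-xxl")
encoder = T5EncoderModel.from_pretrained("t5-v1_1-xxl", use_safetensors=True)

tokens = tokenizer("a watercolor painting of a lighthouse", return_tensors="pt")
hidden = encoder(**tokens).last_hidden_state
print(hidden.shape)  # the last dimension is 4096 for t5-v1_1-xxl
```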
### 3. Modify the files in `configs`.
Here is an example using `CogView3-Base`, with explanations for some of the parameters:
```yaml
args:
  mode: inference
  relay_model: False # Set to True when using CogView-3-Relay
  load: "cogview3_base/transformer" # Path to the transformer folder
  batch_size: 8 # Number of images per inference
  grid_num_columns: 2 # Number of columns in grid.png output
  input_type: txt # Input can be from command line or TXT file
  input_file: configs/test.txt # Not needed for command line input
  fp16: True # Set to bf16 for CogView-3-Plus inference
  # bf16: True
  sampling_image_size: 512 # Fixed size, supports 512x512 resolution images
  # For CogView-3-Plus, use the following:
  # sampling_image_size_x: 1024 (width)
  # sampling_image_size_y: 1024 (height)
  output_dir: "outputs/cogview3_base-512x512"
  # This section is for CogView-3-Relay. Set the input_dir to the folder with base model generated images.
  # input_dir: "outputs/cogview3_base-512x512"
  deepspeed_config: { }
model:
  conditioner_config:
    target: sgm.modules.GeneralConditioner
    params:
      emb_models:
        - is_trainable: False
          input_key: txt
          target: sgm.modules.encoders.modules.FrozenT5Embedder
          params:
            model_dir: "google/t5-v1_1-xxl" # Path to T5 safetensors
            max_length: 225 # Maximum prompt length
  first_stage_config:
    target: sgm.models.autoencoder.AutoencodingEngine
    params:
      ckpt_path: "cogview3_base/vae/imagekl_ch16.pt" # Path to VAE PT file
      monitor: val/rec_loss
```
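For reference, the `args:` block above is not a separate config namespace: `arguments.py` in this folder loads the YAML files passed via `--base` with OmegaConf and copies each key onto the parsed command-line arguments (see `process_config_to_args`). A simplified sketch of that merge:
```python
# Condensed view of how the YAML `args:` block overrides the CLI defaults
# (simplified from process_config_to_args in arguments.py).
from omegaconf import OmegaConf

configs = [OmegaConf.load(p) for p in ["configs/cogview3_base.yaml"]]
config = OmegaConf.merge(*configs)
args_config = config.pop("args", OmegaConf.create())
for key, value in args_config.items():
    print(f"{key} -> {value}")  # e.g. mode -> inference, batch_size -> 8
```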
### 4. Running the model
Different models require different code for inference. Here are the inference commands for each model:
### CogView-3-Plus
```shell
python sample_dit.py --base configs/cogview3_plus.yaml
```
### CogView-3-Base
+ Original model
```shell
python sample_unet.py --base configs/cogview3_base.yaml
```
+ Distilled model
```bash
python sample_unet.py --base configs/cogview3_base_distill_4step.yaml
```
### CogView-3-Relay
+ Original model
```shell
python sample_unet.py --base configs/cogview3_relay.yaml
```
+ Distilled model
```shell
python sample_unet.py --base configs/cogview3_relay_distill_1step.yaml
```
Each prompt produces an output folder whose name consists of the sequence number and the first 15 characters of the prompt. The folder contains multiple images; the number of images is determined by the `batch_size` parameter. The structure should look like this:
```
.
├── 000000000.png
├── 000000001.png
├── 000000002.png
├── 000000003.png
├── 000000004.png
├── 000000005.png
├── 000000006.png
├── 000000007.png
└── grid.png
1 directory, 9 files
```
In this example, `batch_size` is 8, so there are 8 images plus one `grid.png`.
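The `grid.png` is just the batch tiled with `torchvision.utils.make_grid`, using `grid_num_columns` as the number of images per row (see `perform_save_locally` and the `make_grid` call in the sampling scripts). A minimal illustration:
```python
# How grid.png relates to batch_size and grid_num_columns
# (mirrors the make_grid call in sample_unet.py / sample_dit.py).
import torch
from torchvision.utils import make_grid

batch_size, grid_num_columns = 8, 2
samples = torch.rand(batch_size, 3, 512, 512)     # stand-in for decoded images in [0, 1]
grid = make_grid(samples, nrow=grid_num_columns)  # tiles the batch into 4 rows of 2 images
print(grid.shape)                                  # [3, H, W] of the tiled grid image
```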
# SAT CogView3 & CogView-3-Plus
This folder contains inference code that uses the [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, as well as fine-tuning code for the SAT weights.
This is the framework the team used during model training. It has few comments and requires careful study.
## Step-by-step guide to running the model
### 1. Environment setup
Make sure you have installed the dependencies required by this folder:
```shell
pip install -r requirements.txt
```
### 2. Download model weights
The following links are for the different model weights:
### CogView-3-Plus-3B
+ transformer: https://cloud.tsinghua.edu.cn/d/f913eabd3f3b4e28857c
+ vae: https://cloud.tsinghua.edu.cn/d/af4cc066ce8a4cf2ab79
### CogView-3-Base-3B
+ transformer:
  + cogview3-base: https://cloud.tsinghua.edu.cn/d/242b66daf4424fa99bf0
  + cogview3-base-distill-4step: https://cloud.tsinghua.edu.cn/d/d10032a94db647f5aa0e
  + cogview3-base-distill-8step: https://cloud.tsinghua.edu.cn/d/1598d4fe4ebf4afcb6ae

  **These three versions are interchangeable. Choose the one that suits your needs and run it with the corresponding configuration file.**
+ vae: https://cloud.tsinghua.edu.cn/d/c8b9497fc5124d71818a/
### CogView-3-Base-3B-Relay
+ transformer:
  + cogview3-relay: https://cloud.tsinghua.edu.cn/d/134951acced949c1a9e1/
  + cogview3-relay-distill-2step: https://cloud.tsinghua.edu.cn/d/6a902976fcb94ac48402
  + cogview3-relay-distill-1step: https://cloud.tsinghua.edu.cn/d/4d50ec092c64418f8418/

  **These three versions are interchangeable. Choose the one that suits your needs and run it with the corresponding configuration file.**
+ vae: Same as CogView-3-Base-3B
Next, arrange the model files into the following format:
```
.cogview3-plus-3b
├── transformer
│   ├── 1
│   │   └── mp_rank_00_model_states.pt
│   └── latest
└── vae
    └── imagekl_ch16.pt
```
Clone the T5 model. The T5 model is not trained or fine-tuned, but it is required. You can download it separately, but it must be in `safetensors` format, not `bin` format (otherwise an error may occur).
Since we have already uploaded the T5 model in `safetensors` format for `CogVideoX`, a simple approach is to clone it from the `CogVideoX-2B` repository and move the relevant files into the corresponding folder.
```shell
git clone https://huggingface.co/THUDM/CogVideoX-2b.git # Download the model from Hugging Face
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # Download the model from ModelScope
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
```
With this setup, you will have the T5 model in `safetensors` format, ensuring it loads without errors during DeepSpeed fine-tuning.
```
├── added_tokens.json
├── config.json
├── model-00001-of-00002.safetensors
├── model-00002-of-00002.safetensors
├── model.safetensors.index.json
├── special_tokens_map.json
├── spiece.model
└── tokenizer_config.json
0 directories, 8 files
```
### 3. Modify the files in `configs`
Here is an example using `CogView3-Base`, with explanations of some of the parameters:
```yaml
args:
  mode: inference
  relay_model: False # Set this to True when the model type is CogView-3-Relay
  load: "cogview3_base/transformer" # Path to the transformer folder
  batch_size: 8 # Number of images per inference
  grid_num_columns: 2 # After inference, each prompt folder contains a grid.png; this is its number of columns
  input_type: txt # Input can come from the command line or a TXT file
  input_file: configs/test.txt # Not needed when using command-line input
  fp16: True # Switch to bf16 for CogView-3-Plus inference
  # bf16: True
  sampling_image_size: 512 # Fixed size, supports 512x512 resolution images
  # The CogView-3-Plus model uses the following two parameters instead:
  # sampling_image_size_x: 1024 (width)
  # sampling_image_size_y: 1024 (height)
  output_dir: "outputs/cogview3_base-512x512"
  # This part is for the CogView-3-Relay model. Set input_dir to the folder with the images generated by the base model; the prompts should match those used when the base model generated the images.
  # input_dir: "outputs/cogview3_base-512x512"
  deepspeed_config: { }
model:
  conditioner_config:
    target: sgm.modules.GeneralConditioner
    params:
      emb_models:
        - is_trainable: False
          input_key: txt
          target: sgm.modules.encoders.modules.FrozenT5Embedder
          params:
            model_dir: "google/t5-v1_1-xxl" # Absolute path to the T5 safetensors folder
            max_length: 225 # Maximum supported prompt length
  first_stage_config:
    target: sgm.models.autoencoder.AutoencodingEngine
    params:
      ckpt_path: "cogview3_base/vae/imagekl_ch16.pt" # Absolute path to the VAE .pt file
      monitor: val/rec_loss
```
### 4. Running the model
Different models require different inference code. The inference commands for each model are listed below:
### CogView-3-Plus
```shell
python sample_dit.py --base configs/cogview3_plus.yaml
```
### CogView-3-Base
+ Original model
```shell
python sample_unet.py --base configs/cogview3_base.yaml
```
+ Distilled model
```bash
python sample_unet.py --base configs/cogview3_base_distill_4step.yaml
```
### CogView-3-Relay
+ Original model
```shell
python sample_unet.py --base configs/cogview3_relay.yaml
```
+ Distilled model
```shell
python sample_unet.py --base configs/cogview3_relay_distill_1step.yaml
```
Each prompt produces an output folder whose name consists of the sequence number and the first 15 characters of the prompt. The folder contains multiple images; the number of images is determined by the `batch_size` parameter.
The structure should look like this:
```
.
├── 000000000.png
├── 000000001.png
├── 000000002.png
├── 000000003.png
├── 000000004.png
├── 000000005.png
├── 000000006.png
├── 000000007.png
└── grid.png
1 directory, 9 files
```
In this example, `batch_size` is 8, so there are 8 images plus one `grid.png`.
import argparse
import os
import torch
import json
import warnings
import omegaconf
from omegaconf import OmegaConf
from sat.helpers import print_rank0
from sat import mpu
from sat.arguments import set_random_seed
from sat.arguments import add_training_args, add_evaluation_args, add_data_args
def add_model_config_args(parser):
"""Model arguments"""
group = parser.add_argument_group("model", "model configuration")
group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
group.add_argument(
"--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
)
group.add_argument("--force-pretrain", action="store_true")
group.add_argument("--device", type=int, default=-1)
return parser
def add_sampling_config_args(parser):
"""Sampling configurations"""
group = parser.add_argument_group("sampling", "Sampling Configurations")
group.add_argument("--input-dir", type=str, default=None)
group.add_argument("--output-dir", type=str, default="samples")
group.add_argument("--input-type", type=str, default="cli")
group.add_argument("--relay-model", type=bool, default=False)
group.add_argument("--input-file", type=str, default="input.txt")
group.add_argument("--sampling-image-size", type=int, default=1024)
group.add_argument("--sampling-latent-dim", type=int, default=4)
group.add_argument("--sampling-f", type=int, default=8)
group.add_argument("--sampling-image-size-x", type=int, default=None)
group.add_argument("--sampling-image-size-y", type=int, default=None)
group.add_argument("--sdedit", action="store_true")
group.add_argument("--ip2p", action="store_true")
group.add_argument("--grid-num-columns", type=int, default=1)
group.add_argument("--force-inference", action="store_true")
return parser
def add_additional_config_args(parser):
group = parser.add_argument_group("additional", "Additional Configurations")
group.add_argument("--multiaspect-training", action="store_true")
group.add_argument("--multiaspect-shapes", nargs="+", default=None, type=int)
return parser
def get_args(args_list=None, parser=None):
"""Parse all the args."""
if parser is None:
parser = argparse.ArgumentParser(description="sat")
else:
assert isinstance(parser, argparse.ArgumentParser)
parser = add_model_config_args(parser)
parser = add_sampling_config_args(parser)
parser = add_training_args(parser)
parser = add_evaluation_args(parser)
parser = add_data_args(parser)
parser = add_additional_config_args(parser)
# Include DeepSpeed configuration arguments
import deepspeed
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args(args_list)
args = process_config_to_args(args)
if not args.train_data:
print_rank0("No training data specified", level="WARNING")
assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
if args.train_iters is None and args.epochs is None:
args.train_iters = 10000 # default 10k iters
print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
args.cuda = torch.cuda.is_available()
args.rank = int(os.getenv("RANK", "0"))
args.world_size = int(os.getenv("WORLD_SIZE", "1"))
if args.local_rank is None:
args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun
if args.device == -1: # not set manually
if torch.cuda.device_count() == 0:
args.device = "cpu"
elif args.local_rank is not None:
args.device = args.local_rank
else:
args.device = args.rank % torch.cuda.device_count()
# local rank should be consistent with device in DeepSpeed
if args.local_rank != args.device and args.mode != "inference":
raise ValueError(
"LOCAL_RANK (default 0) and args.device inconsistent. "
"This can only happens in inference mode. "
"Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. "
)
# args.model_parallel_size = min(args.model_parallel_size, args.world_size)
if args.rank == 0:
print_rank0("using world size: {}".format(args.world_size))
# if args.vocab_size > 0:
# _adjust_vocab_size(args)
if args.train_data_weights is not None:
assert len(args.train_data_weights) == len(args.train_data)
if args.mode != "inference": # training with deepspeed
args.deepspeed = True
if args.deepspeed_config is None: # not specified
deepspeed_config_path = os.path.join(
os.path.dirname(__file__), "training", f"deepspeed_zero{args.zero_stage}.json"
)
with open(deepspeed_config_path) as file:
args.deepspeed_config = json.load(file)
override_deepspeed_config = True
else:
override_deepspeed_config = False
assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16."
if args.zero_stage > 0 and not args.fp16 and not args.bf16:
print_rank0("Automatically set fp16=True to use ZeRO.")
args.fp16 = True
args.bf16 = False
if args.deepspeed:
if args.checkpoint_activations:
args.deepspeed_activation_checkpointing = True
else:
args.deepspeed_activation_checkpointing = False
if args.deepspeed_config is not None:
deepspeed_config = args.deepspeed_config
# with open(args.deepspeed_config) as file:
# deepspeed_config = json.load(file)
if override_deepspeed_config: # not specify deepspeed_config, use args
if args.fp16:
deepspeed_config["fp16"]["enabled"] = True
elif args.bf16:
deepspeed_config["bf16"]["enabled"] = True
deepspeed_config["fp16"]["enabled"] = False
else:
deepspeed_config["fp16"]["enabled"] = False
deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size
deepspeed_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
optimizer_params_config = deepspeed_config["optimizer"]["params"]
optimizer_params_config["lr"] = args.lr
optimizer_params_config["weight_decay"] = args.weight_decay
else: # override args with values in deepspeed_config
if args.rank == 0:
print_rank0("Will override arguments with manually specified deepspeed_config!")
if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
args.fp16 = True
else:
args.fp16 = False
if "bf16" in deepspeed_config and deepspeed_config["bf16"]["enabled"]:
args.bf16 = True
else:
args.bf16 = False
if "train_micro_batch_size_per_gpu" in deepspeed_config:
args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
if "gradient_accumulation_steps" in deepspeed_config:
args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
else:
args.gradient_accumulation_steps = None
if "optimizer" in deepspeed_config:
optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
args.lr = optimizer_params_config.get("lr", args.lr)
args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
args.deepspeed_config = deepspeed_config
# if args.sandwich_ln: # removed in v0.3
# args.layernorm_order = 'sandwich'
# initialize distributed and random seed because it always seems to be necessary.
initialize_distributed(args)
args.seed = args.seed + torch.distributed.get_rank()
set_random_seed(args.seed)
return args
def initialize_distributed(args):
"""Initialize torch.distributed."""
if torch.distributed.is_initialized():
if mpu.model_parallel_is_initialized():
if args.model_parallel_size != mpu.get_model_parallel_world_size():
raise ValueError(
"model_parallel_size is inconsistent with prior configuration."
"We currently do not support changing model_parallel_size."
)
return False
else:
if args.model_parallel_size > 1:
warnings.warn(
"model_parallel_size > 1 but torch.distributed is not initialized via SAT."
"Please carefully make sure the correctness on your own."
)
mpu.initialize_model_parallel(args.model_parallel_size)
return True
# the automatic assignment of devices has been moved to arguments.py
if args.device == "cpu":
pass
else:
torch.cuda.set_device(args.device)
# Call the init process
init_method = "tcp://"
args.master_ip = os.getenv("MASTER_ADDR", "localhost")
if args.world_size == 1:
from sat.helpers import get_free_port
default_master_port = str(get_free_port())
else:
default_master_port = "6000"
args.master_port = os.getenv("MASTER_PORT", default_master_port)
init_method += args.master_ip + ":" + args.master_port
torch.distributed.init_process_group(
backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
)
# Set the model-parallel / data-parallel communicators.
# mpu.initialize_model_parallel(args.model_parallel_size)
mpu.initialize_model_parallel(1)
# Optional DeepSpeed Activation Checkpointing Features
if args.deepspeed:
import deepspeed
deepspeed.init_distributed(
dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
)
# # It seems that it has no negative influence to configure it even without using checkpointing.
# deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
else:
# in model-only mode, we don't want to init deepspeed, but we still need to init the rng tracker for model_parallel, just because we save the seed by default when dropout.
try:
import deepspeed
from deepspeed.runtime.activation_checkpointing.checkpointing import (
_CUDA_RNG_STATE_TRACKER,
_MODEL_PARALLEL_RNG_TRACKER_NAME,
)
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1) # default seed 1
except Exception as e:
from sat.helpers import print_rank0
print_rank0(str(e), level="DEBUG")
return True
def process_config_to_args(args):
"""Fetch args from only --base"""
configs = [OmegaConf.load(cfg) for cfg in args.base]
config = OmegaConf.merge(*configs)
args_config = config.pop("args", OmegaConf.create())
for key in args_config:
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
arg = OmegaConf.to_object(args_config[key])
else:
arg = args_config[key]
if hasattr(args, key):
setattr(args, key, arg)
if "model" in config:
model_config = config.pop("model", OmegaConf.create())
args.model_config = model_config
if "deepspeed" in config:
deepspeed_config = config.pop("deepspeed", OmegaConf.create())
args.deepspeed_config = OmegaConf.to_object(deepspeed_config)
if "data" in config:
data_config = config.pop("data", OmegaConf.create())
args.data_config = data_config
return args
args:
mode: inference
relay_model: False
load: "/home/models/CogView4/CogView3/cogview3-base/transformer"
batch_size: 4
grid_num_columns: 2
input_type: txt
input_file: "configs/test.txt"
fp16: True
force_inference: True
sampling_image_size: 512
output_dir: "outputs/cogview3_base-512x512"
deepspeed_config: { }
model:
scale_factor: 0.13025
disable_first_stage_autocast: true
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 1536
num_classes: sequential
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4 ]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [ 1, 2, 10 ] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 4096
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "/home/models/CogView4/t5-v1_1-xxl"
max_length: 225
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
ckpt_path: "/home/models/CogView4/CogView3/cogview3-base/vae/sdxl_vae.safetensors"
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
loss_fn_config:
target: torch.nn.Identity
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 10
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 7.5
args:
mode: inference
relay_model: False
load: "/home/models/CogView4/CogView3/cogview3-base/transformer_distill_4step"
batch_size: 4
grid_num_columns: 2
input_type: txt
input_file: "configs/test.txt"
fp16: True
force_inference: True
sampling_image_size: 512
output_dir: "outputs/cogview3_base_distill-4step"
deepspeed_config: {}
model:
scale_factor: 0.13025
disable_first_stage_autocast: true
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 1536
num_classes: sequential
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 4096
spatial_transformer_attn_type: softmax-xformers
legacy: False
cfg_cond_embed_dim: 512
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "/home/models/CogView4/t5-v1_1-xxl"
max_length: 225
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
ckpt_path: "/home/models/CogView4/CogView3/cogview3-base/vae/sdxl_vae.safetensors"
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
loss_fn_config:
target: torch.nn.Identity
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
cfg_cond_scale: 7.5
cfg_cond_embed_dim: 512
num_steps: 4
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.IdentityGuider
args:
mode: inference
relay_model: False
load: "/home/models/CogView4/CogView3/cogview3-plus-3b/transformer"
batch_size: 4
grid_num_columns: 2
input_type: txt
input_file: "configs/test.txt"
bf16: True
force_inference: True
sampling_image_size_x: 512
sampling_image_size_y: 512
sampling_latent_dim: 16
output_dir: "outputs/cogview3_plus"
deepspeed_config: { }
model:
scale_factor: 1
disable_first_stage_autocast: true
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
quantize_c_noise: False
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.ZeroSNRScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 4
network_config:
target: sgm.modules.diffusionmodules.dit.DiffusionTransformer
params:
in_channels: 16
out_channels: 16
hidden_size: 2560
num_layers: 30
patch_size: 2
block_size: 16
num_attention_heads: 64
text_length: 224
time_embed_dim: 512
num_classes: sequential
adm_in_channels: 1536
modules:
pos_embed_config:
target: sgm.modules.diffusionmodules.dit.PositionEmbeddingMixin
params:
max_height: 128
max_width: 128
max_length: 4096
patch_embed_config:
target: sgm.modules.diffusionmodules.dit.ImagePatchEmbeddingMixin
params:
text_hidden_size: 4096
attention_config:
target: sgm.modules.diffusionmodules.dit.AdalnAttentionMixin
params:
qk_ln: true
final_layer_config:
target: sgm.modules.diffusionmodules.dit.FinalLayerMixin
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "/home/models/CogView4/t5-v1_1-xxl"
max_length: 224
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
ckpt_path: "/home/models/CogView4/CogView3/cogview3-plus-3b/vae/imagekl_ch16.pt"
monitor: val/rec_loss
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: vanilla
double_z: true
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 4, 8, 8 ]
num_res_blocks: 3
attn_resolutions: [ ]
mid_attn: False
dropout: 0.0
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params:
attn_type: vanilla
double_z: true
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 4, 8, 8 ]
num_res_blocks: 3
attn_resolutions: [ ]
mid_attn: False
dropout: 0.0
loss_fn_config:
target: torch.nn.Identity
sampler_config:
target: sgm.modules.diffusionmodules.sampling.ZeroSNRDDIMSampler
params:
num_steps: 10
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 4
guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 5
args:
mode: inference
relay_model: True
load: "/home/models/CogView4/CogView3/cogview3-relay/transformer"
batch_size: 4
grid_num_columns: 2
input_type: txt
input_file: "configs/test.txt"
fp16: True
force_inference: True
sampling_image_size: 1024
output_dir: "outputs/cogview3_relay-1024x1024"
input_dir: "outputs/cogview3_base-512x512"
deepspeed_config: { }
model:
scale_factor: 0.13025
disable_first_stage_autocast: true
lr_scale: 2
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 1536
num_classes: sequential
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4 ]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
# note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
transformer_depth: [ 1, 2, 10 ]
context_dim: 4096
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "/home/models/CogView4/t5-v1_1-xxl"
max_length: 225
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
ckpt_path: "/home/models/CogView4/CogView3/cogview3-relay/vae/sdxl_vae.safetensors"
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.LinearRelayDiffusionLoss
params:
offset_noise_level: 0.05
partial_num_steps: 500
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
sampler_config:
target: sgm.modules.diffusionmodules.sampling.LinearRelayEDMSampler
params:
# Suggested config
partial_num_steps: 12
num_steps: 24
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 7.5
args:
mode: inference
load: "transformer"
batch_size: 4
grid_num_columns: 2
input_type: txt
input_file: configs/test.txt
fp16: True
force_inference: True
sampling_image_size: 1024 # This value should be twice the resolution of the input images
output_dir: "outputs/cogview3_relay_distill_1step"
input_dir: "inputs" # the inputs image should follow the order of input_file or cli input
deepspeed_config: { }
model:
scale_factor: 0.13025
disable_first_stage_autocast: true
lr_scale: 2
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 1536
num_classes: sequential
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4 ]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
# note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
transformer_depth: [ 1, 2, 10 ]
context_dim: 4096
spatial_transformer_attn_type: softmax-xformers
legacy: False
cfg_cond_embed_dim: 256 # This differs from the original model.
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "google/t5-v1_1-xxl"
max_length: 225
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
ckpt_path: "vae/sdxl_vae.safetensors"
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
loss_fn_config:
target: torch.nn.Identity
sampler_config:
target: sgm.modules.diffusionmodules.sampling.LinearRelayEDMSampler
params:
# Suggested config; differs significantly from the original model.
cfg_cond_scale: 7.5
cfg_cond_embed_dim: 256
partial_num_steps: 1
num_steps: 5
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.IdentityGuider
import math
from typing import Any, Dict, List, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
from sgm.modules import UNCONDITIONAL_CONFIG
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from sgm.util import default, get_obj_from_str, instantiate_from_config
class SATDiffusionEngine(nn.Module):
def __init__(self, args, **kwargs):
super().__init__()
model_config = args.model_config
# model args preprocess
log_keys = model_config.get("log_keys", None)
input_key = model_config.get("input_key", "jpg")
network_config = model_config.get("network_config", None)
network_wrapper = model_config.get("network_wrapper", None)
denoiser_config = model_config.get("denoiser_config", None)
sampler_config = model_config.get("sampler_config", None)
conditioner_config = model_config.get("conditioner_config", None)
first_stage_config = model_config.get("first_stage_config", None)
loss_fn_config = model_config.get("loss_fn_config", None)
scale_factor = model_config.get("scale_factor", 1.0)
disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
no_cond_log = model_config.get("no_cond_log", False)
untrainable_prefixs = model_config.get("untrainable_prefixs", ["first_stage_model", "conditioner"])
compile_model = model_config.get("compile_model", False)
en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
lr_scale = model_config.get("lr_scale", None)
use_pd = model_config.get("use_pd", False) # progressive distillation
self.log_keys = log_keys
self.input_key = input_key
self.untrainable_prefixs = untrainable_prefixs
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
self.lr_scale = lr_scale
self.use_pd = use_pd
if args.fp16:
dtype = torch.float16
dtype_str = "fp16"
elif args.bf16:
dtype = torch.bfloat16
dtype_str = "bf16"
else:
dtype = torch.float32
dtype_str = "fp32"
self.dtype = dtype
self.dtype_str = dtype_str
network_config["params"]["dtype"] = dtype_str
model = instantiate_from_config(network_config)
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
model, compile_model=compile_model, dtype=dtype
)
self.denoiser = instantiate_from_config(denoiser_config)
self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
first_stage_model = instantiate_from_config(first_stage_config).eval()
for param in first_stage_model.parameters():
param.requires_grad = False
self.first_stage_model = first_stage_model
self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
self.scale_factor = scale_factor
self.disable_first_stage_autocast = disable_first_stage_autocast
self.no_cond_log = no_cond_log
self.device = args.device
@torch.no_grad()
def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples])
all_out.append(out)
out = torch.cat(all_out, dim=0)
return out
@torch.no_grad()
def encode_first_stage(self, x):
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
n_rounds = math.ceil(x.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples])
all_out.append(out)
z = torch.cat(all_out, dim=0)
z = self.scale_factor * z
return z
def forward(self, x, batch, **kwargs):
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
loss_mean = loss.mean()
loss_dict = {"loss": loss_mean}
return loss_mean, loss_dict
def shared_step(self, batch: Dict) -> Any:
x = self.get_input(batch)
if self.lr_scale is not None:
lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
lr_z = self.encode_first_stage(lr_x)
batch["lr_input"] = lr_z
x = self.encode_first_stage(x)
# batch["global_step"] = self.global_step
loss, loss_dict = self(x, batch)
return loss, loss_dict
@torch.no_grad()
def sample(
self,
cond: Dict,
uc: Union[Dict, None] = None,
batch_size: int = 16,
shape: Union[None, Tuple, List] = None,
target_size=None,
**kwargs,
):
randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
if target_size is not None:
denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
self.model, input, sigma, c, target_size=target_size, **additional_model_inputs
)
else:
denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
self.model, input, sigma, c, **additional_model_inputs
)
samples = self.sampler(denoiser, randn, cond, uc=uc)
if isinstance(samples, list):
for i in range(len(samples)):
samples[i] = samples[i].to(self.dtype)
else:
samples = samples.to(self.dtype)
return samples
@torch.no_grad()
def sample_relay(
self,
image: torch.Tensor,
cond: Dict,
uc: Union[Dict, None] = None,
batch_size: int = 16,
shape: Union[None, Tuple, List] = None,
**kwargs,
):
randn = torch.randn(batch_size, *shape).to(self.dtype).to(self.device)
denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
self.model, input, sigma, c, **additional_model_inputs
)
samples = self.sampler(denoiser, image, randn, cond, uc=uc)
if isinstance(samples, list):
for i in range(len(samples)):
samples[i] = samples[i].to(self.dtype)
else:
samples = samples.to(self.dtype)
return samples
deepspeed>=0.15.1
transformers>=4.45.0
xformers>=0.0.28
torch>=2.4.0
pytorch_lightning>=2.4.0
torchvision>=0.19.0
einops>=0.8.0
fsspec>=2024.6.1
kornia>=0.7.3
numpy>=2.1.1
omegaconf>=2.3.0
open_clip_torch>=2.26.1
Pillow>=10.4.0
safetensors>=0.4.5
scipy>=1.14.1
SwissArmyTransformer>=0.4.12
tqdm>=4.66.5
wandb>=0.18.1
openai>=1.48.0
sentencepiece>=0.2.0
import os
import math
import argparse
from typing import List, Union
from tqdm import tqdm
from omegaconf import ListConfig
from PIL import Image
import torch
import numpy as np
from einops import rearrange, repeat
from torchvision.utils import make_grid
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from diffusion import SATDiffusionEngine
from arguments import get_args
def read_from_cli():
cnt = 0
try:
while True:
x = input("Please input English text (Ctrl-D quit): ")
yield x.strip(), cnt
cnt += 1
except EOFError as e:
pass
def read_from_file(p, rank=0, world_size=1):
with open(p, "r") as fin:
cnt = -1
for l in fin:
cnt += 1
if cnt % world_size != rank:
continue
yield l.strip(), cnt
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
batch = {}
batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1)
)
elif key == "fps":
batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
elif key == "fps_id":
batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
elif key == "motion_bucket_id":
batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
elif key == "pool_image":
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
elif key == "cond_aug":
batch[key] = repeat(
torch.tensor([value_dict["cond_aug"]]).to("cuda"),
"1 -> b",
b=math.prod(N),
)
elif key == "cond_frames":
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
elif key == "cond_frames_without_noise":
batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
elif key == "cfg_scale":
batch[key] = torch.tensor([value_dict["cfg_scale"]]).to(device).repeat(math.prod(N))
else:
batch[key] = value_dict[key]
if T is not None:
batch["num_video_frames"] = T
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
def perform_save_locally(save_path, samples, grid, only_save_grid=False):
os.makedirs(save_path, exist_ok=True)
if not only_save_grid:
for i, sample in enumerate(samples):
sample = 255.0 * rearrange(sample.numpy(), "c h w -> h w c")
Image.fromarray(sample.astype(np.uint8)).save(os.path.join(save_path, f"{i:09}.png"))
if grid is not None:
grid = 255.0 * rearrange(grid.numpy(), "c h w -> h w c")
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(save_path, f"grid.png"))
def sampling_main(args, model_cls):
if isinstance(model_cls, type):
model = get_model(args, model_cls)
else:
model = model_cls
load_checkpoint(model, args)
model.eval()
if args.input_type == "cli":
data_iter = read_from_cli()
elif args.input_type == "txt":
rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()
data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
else:
raise NotImplementedError
image_size_x = args.sampling_image_size_x
image_size_y = args.sampling_image_size_y
image_size = (image_size_x, image_size_y)
latent_dim = args.sampling_latent_dim
f = args.sampling_f
assert (
image_size_x >= 512 and image_size_y >= 512 and image_size_x <= 2048 and image_size_y <= 2048
), "Image size should be between 512 and 2048"
assert image_size_x % 32 == 0 and image_size_y % 32 == 0, "Image size should be divisible by 32"
sample_func = model.sample
H, W, C, F = image_size_x, image_size_y, latent_dim, f
num_samples = [args.batch_size]
force_uc_zero_embeddings = ["txt"]
with torch.no_grad():
for text, cnt in tqdm(data_iter):
value_dict = {
"prompt": text,
"negative_prompt": "",
"original_size_as_tuple": image_size,
"target_size_as_tuple": image_size,
"orig_height": image_size_x,
"orig_width": image_size_y,
"target_height": image_size_x,
"target_width": image_size_y,
"crop_coords_top": 0,
"crop_coords_left": 0,
}
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
samples_z = sample_func(
c,
uc=uc,
batch_size=args.batch_size,
shape=(C, H // F, W // F),
target_size=[image_size],
)
samples_x = model.decode_first_stage(samples_z).to(torch.float32)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
batch_size = samples.shape[0]
assert (batch_size // args.grid_num_columns) * args.grid_num_columns == batch_size
if args.batch_size == 1:
grid = None
else:
grid = make_grid(samples, nrow=args.grid_num_columns)
save_path = os.path.join(args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:20])
perform_save_locally(save_path, samples, grid)
if __name__ == "__main__":
py_parser = argparse.ArgumentParser(add_help=False)
known, args_list = py_parser.parse_known_args()
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
sampling_main(args, model_cls=SATDiffusionEngine)
import os
import math
import argparse
from tqdm import tqdm
from typing import List, Union
from omegaconf import ListConfig
from PIL import Image
import torch
import torch.nn.functional as functional
import numpy as np
from einops import rearrange, repeat
from torchvision.utils import make_grid
import torchvision.transforms as TT
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from diffusion import SATDiffusionEngine
from arguments import get_args
def read_from_cli():
cnt = 0
try:
while True:
x = input("Please input English text (Ctrl-D quit): ")
yield x.strip(), cnt
cnt += 1
except EOFError as e:
pass
def read_from_file(p, rank=0, world_size=1):
with open(p, "r") as fin:
cnt = -1
for l in fin:
cnt += 1
if cnt % world_size != rank:
continue
yield l.strip(), cnt
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
batch = {}
batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1)
)
elif key == "fps":
batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
elif key == "fps_id":
batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
elif key == "motion_bucket_id":
batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
elif key == "pool_image":
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
elif key == "cond_aug":
batch[key] = repeat(
torch.tensor([value_dict["cond_aug"]]).to("cuda"),
"1 -> b",
b=math.prod(N),
)
elif key == "cond_frames":
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
elif key == "cond_frames_without_noise":
batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
elif key == "cfg_scale":
batch[key] = torch.tensor([value_dict["cfg_scale"]]).to(device).repeat(math.prod(N))
else:
batch[key] = value_dict[key]
if T is not None:
batch["num_video_frames"] = T
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
def perform_save_locally(save_path, samples, grid, only_save_grid=False):
os.makedirs(save_path, exist_ok=True)
if not only_save_grid:
for i, sample in enumerate(samples):
sample = 255.0 * rearrange(sample.numpy(), "c h w -> h w c")
print(f"1111111111111111111111111111 Max: {sample.max()}, Min: {sample.min()}")
Image.fromarray(sample.astype(np.uint8)).save(os.path.join(save_path, f"{i:09}.png"))
if grid is not None:
grid = 255.0 * rearrange(grid.numpy(), "c h w -> h w c")
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(save_path, f"grid.png"))
def sampling_main(args, model_cls):
if isinstance(model_cls, type):
model = get_model(args, model_cls)
else:
model = model_cls
load_checkpoint(model, args)
model.eval()
if args.input_type == "cli":
data_iter = read_from_cli()
elif args.input_type == "txt":
rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()
data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
else:
raise NotImplementedError
image_size = args.sampling_image_size
input_sample_dirs = None
if args.relay_model is True:
sample_func = model.sample_relay
H, W, C, F = image_size, image_size, 4, 8
assert args.input_dir is not None
input_sample_dirs = os.listdir(args.input_dir)
input_sample_dirs_and_rank = sorted([(int(name.split("_")[0]), name) for name in input_sample_dirs])
input_sample_dirs = [os.path.join(args.input_dir, name) for _, name in input_sample_dirs_and_rank]
else:
sample_func = model.sample
latent_dim = args.sampling_latent_dim
f = args.sampling_f
H, W, C, F = image_size, image_size, latent_dim, f
num_samples = [args.batch_size]
force_uc_zero_embeddings = ["txt"]
with torch.no_grad():
for text, cnt in tqdm(data_iter):
value_dict = {
"prompt": text,
"negative_prompt": "",
"original_size_as_tuple": (image_size, image_size),
"target_size_as_tuple": (image_size, image_size),
"orig_height": image_size,
"orig_width": image_size,
"target_height": image_size,
"target_width": image_size,
"crop_coords_top": 0,
"crop_coords_left": 0,
}
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
if args.relay_model is True:
input_sample_dir = input_sample_dirs[cnt]
images = []
for i in range(args.batch_size):
filepath = os.path.join(input_sample_dir, f"{i:09}.png")
image = Image.open(filepath).convert("RGB")
image = TT.ToTensor()(image) * 2 - 1
images.append(image[None, ...])
images = torch.cat(images, dim=0)
images = functional.interpolate(images, scale_factor=2, mode="bilinear", align_corners=False)
images = images.to(torch.float16).cuda()
images = model.encode_first_stage(images)
samples_z = sample_func(images, c, uc=uc, batch_size=args.batch_size, shape=(C, H // F, W // F))
else:
samples_z = sample_func(c, uc=uc, batch_size=args.batch_size, shape=(C, H // F, W // F))
samples_x = model.decode_first_stage(samples_z).to(torch.float32)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
batch_size = samples.shape[0]
assert (batch_size // args.grid_num_columns) * args.grid_num_columns == batch_size
if args.batch_size == 1:
grid = None
else:
grid = make_grid(samples, nrow=args.grid_num_columns)
save_path = os.path.join(args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:20])
perform_save_locally(save_path, samples, grid)
if __name__ == "__main__":
py_parser = argparse.ArgumentParser(add_help=False)
known, args_list = py_parser.parse_known_args()
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
sampling_main(args, model_cls=SATDiffusionEngine)
from .models import AutoencodingEngine
from .util import get_configs_path, instantiate_from_config
__version__ = "0.1.0"
from .autoencoder import AutoencodingEngine
import logging
import math
import re
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Tuple, Union
import pytorch_lightning as pl
import torch
from omegaconf import ListConfig
from packaging import version
from safetensors.torch import load_file as load_safetensors
from ..modules.diffusionmodules.model import Decoder, Encoder
from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..modules.ema import LitEma
from ..util import default, get_obj_from_str, instantiate_from_config
class AbstractAutoencoder(pl.LightningModule):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
"""
def __init__(
self,
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = "jpg",
ckpt_path: Union[None, str] = None,
ignore_keys: Union[Tuple, list, ListConfig] = (),
):
super().__init__()
self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None:
self.monitor = monitor
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
if version.parse(torch.__version__) >= version.parse("2.0.0"):
self.automatic_optimization = False
def init_from_ckpt(
self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
) -> None:
if path.endswith("ckpt"):
sd = torch.load(path, map_location="cpu")["state_dict"]
elif path.endswith("safetensors"):
sd = load_safetensors(path)
else:
raise NotImplementedError
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if re.match(ik, k):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
# print(
# f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
# )
# if len(missing) > 0:
# print(f"Missing Keys: {missing}")
# if len(unexpected) > 0:
# print(f"Unexpected Keys: {unexpected}")
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
if isinstance(ckpt, str):
ckpt = {
"target": "sgm.modules.checkpoint.CheckpointEngine",
"params": {"ckpt_path": ckpt},
}
engine = instantiate_from_config(ckpt)
engine(self)
@abstractmethod
def get_input(self, batch) -> Any:
raise NotImplementedError()
def on_train_batch_end(self, *args, **kwargs):
# for EMA computation
if self.use_ema:
self.model_ema(self)
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
@abstractmethod
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")
@abstractmethod
def decode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
print(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
def configure_optimizers(self) -> Any:
raise NotImplementedError()
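# Illustrative subclass sketch (not part of the shipped code): a concrete autoencoder
# fills in the abstract hooks get_input/encode/decode plus configure_optimizers, e.g.
#
#   class ToyAutoencoder(AbstractAutoencoder):
#       def __init__(self, dim: int = 4, **kwargs):
#           super().__init__(**kwargs)
#           self.proj = torch.nn.Linear(dim, dim)
#       def get_input(self, batch):
#           return batch[self.input_key]
#       def encode(self, x):
#           return self.proj(x)
#       def decode(self, z):
#           return z
#       def configure_optimizers(self):
#           return torch.optim.Adam(self.parameters(), lr=1e-4)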
class AutoencodingEngine(AbstractAutoencoder):
"""
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
(we also restore them explicitly as special cases for legacy reasons).
Regularizations such as KL or VQ are moved to the regularizer class.
"""
def __init__(
self,
*args,
encoder_config: Dict,
decoder_config: Dict,
loss_config: Dict,
regularizer_config: Dict,
optimizer_config: Union[Dict, None] = None,
lr_g_factor: float = 1.0,
ckpt_path=None,
ignore_keys=[],
**kwargs,
):
super().__init__(*args, **kwargs)
# todo: add options to freeze encoder/decoder
self.encoder = instantiate_from_config(encoder_config)
self.decoder = instantiate_from_config(decoder_config)
self.loss = instantiate_from_config(loss_config)
self.regularization = instantiate_from_config(regularizer_config)
self.optimizer_config = default(
optimizer_config, {"target": "torch.optim.Adam"}
)
self.lr_g_factor = lr_g_factor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
if path.endswith("ckpt") or path.endswith("pt"):
sd = torch.load(path, map_location="cpu")['state_dict']
elif path.endswith("safetensors"):
sd = load_safetensors(path)
else:
raise NotImplementedError
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
print("Missing keys: ", missing_keys)
print("Unexpected keys: ", unexpected_keys)
print(f"Restored from {path}")
def get_input(self, batch: Dict) -> torch.Tensor:
# assuming unified data format, dataloader returns a dict.
# image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead of bhwc)
return batch[self.input_key]
def get_autoencoder_params(self) -> list:
params = (
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.regularization.get_trainable_parameters())
+ list(self.loss.get_trainable_autoencoder_parameters())
)
return params
def get_discriminator_params(self) -> list:
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
return params
def get_last_layer(self):
return self.decoder.get_last_layer()
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
z = self.encoder(x)
if unregularized:
return z, dict()
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.decoder(z, **kwargs)
return x
def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z)
return z, dec, reg_log
def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
x = self.get_input(batch)
z, xrec, regularization_log = self(x)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(
regularization_log,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(
regularization_log,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
return discloss
def validation_step(self, batch, batch_idx) -> Dict:
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
log_dict.update(log_dict_ema)
return log_dict
def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
x = self.get_input(batch)
z, xrec, regularization_log = self(x)
aeloss, log_dict_ae = self.loss(
regularization_log,
x,
xrec,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + postfix,
)
discloss, log_dict_disc = self.loss(
regularization_log,
x,
xrec,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + postfix,
)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
log_dict_ae.update(log_dict_disc)
self.log_dict(log_dict_ae)
return log_dict_ae
def configure_optimizers(self) -> Any:
ae_params = self.get_autoencoder_params()
disc_params = self.get_discriminator_params()
opt_ae = self.instantiate_optimizer_from_config(
ae_params,
default(self.lr_g_factor, 1.0) * self.learning_rate,
self.optimizer_config,
)
opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config
)
return [opt_ae, opt_disc], []
@torch.no_grad()
def log_images(self, batch: Dict, **kwargs) -> Dict:
log = dict()
x = self.get_input(batch)
_, xrec, _ = self(x)
log["inputs"] = x
log["reconstructions"] = xrec
with self.ema_scope():
_, xrec_ema, _ = self(x)
log["reconstructions_ema"] = xrec_ema
return log
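# Usage sketch (illustrative; the encoder/decoder/regularizer targets are the ones
# referenced elsewhere in this file, while the loss entry is a placeholder, not the
# real training config):
#
#   engine = AutoencodingEngine(
#       encoder_config={"target": "sgm.modules.diffusionmodules.model.Encoder", "params": ddconfig},
#       decoder_config={"target": "sgm.modules.diffusionmodules.model.Decoder", "params": ddconfig},
#       regularizer_config={"target": "sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer"},
#       loss_config={"target": "torch.nn.Identity"},
#   )
#   z, rec, reg_log = engine(x)   # x: (B, C, H, W) scaled to [-1, 1]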
class AutoencodingEngineLegacy(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop("max_batch_size", None)
ddconfig = kwargs.pop("ddconfig")
ckpt_path = kwargs.pop("ckpt_path", None)
ckpt_engine = kwargs.pop("ckpt_engine", None)
super().__init__(
encoder_config={
"target": "sgm.modules.diffusionmodules.model.Encoder",
"params": ddconfig,
},
decoder_config={
"target": "sgm.modules.diffusionmodules.model.Decoder",
"params": ddconfig,
},
**kwargs,
)
self.quant_conv = torch.nn.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
self.apply_ckpt(default(ckpt_path, ckpt_engine))
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
def encode(
self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)
return dec
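# Descriptive note: when max_batch_size is set, encode()/decode() process the input in
# ceil(N / max_batch_size) chunks and concatenate the results along the batch dimension,
# which bounds peak memory without changing the output (e.g. N=10, max_batch_size=4
# gives chunks of 4, 4 and 2).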
class AutoencoderKL(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
ddconfig = kwargs.pop("ddconfig")
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", ())
super().__init__(
encoder_config={"target": "torch.nn.Identity"},
decoder_config={"target": "torch.nn.Identity"},
regularizer_config={"target": "torch.nn.Identity"},
loss_config=kwargs.pop("lossconfig"),
**kwargs,
)
assert ddconfig["double_z"]
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def encode(self, x):
assert (
not self.training
), f"{self.__class__.__name__} only supports inference currently"
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z, **decoder_kwargs):
z = self.post_quant_conv(z)
dec = self.decoder(z, **decoder_kwargs)
return dec
class AutoencoderKLInferenceWrapper(AutoencoderKL):
def encode(self, x):
return super().encode(x).sample()
class IdentityFirstStage(AbstractAutoencoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def get_input(self, x: Any) -> Any:
return x
def encode(self, x: Any, *args, **kwargs) -> Any:
return x
def decode(self, x: Any, *args, **kwargs) -> Any:
return x
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"sgm.modules.autoencoding.regularizers"
".DiagonalGaussianRegularizer"
),
"params": {"sample": False},
},
**kwargs,
)
from .encoders.modules import GeneralConditioner
UNCONDITIONAL_CONFIG = {
"target": "sgm.modules.GeneralConditioner",
"params": {"emb_models": []},
}
import math
from inspect import isfunction
from typing import Any, Optional
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from packaging import version
from torch import nn
if version.parse(torch.__version__) >= version.parse("2.0.0"):
SDP_IS_AVAILABLE = True
from torch.backends.cuda import SDPBackend, sdp_kernel
BACKEND_MAP = {
SDPBackend.MATH: {
"enable_math": True,
"enable_flash": False,
"enable_mem_efficient": False,
},
SDPBackend.FLASH_ATTENTION: {
"enable_math": False,
"enable_flash": True,
"enable_mem_efficient": False,
},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False,
"enable_flash": False,
"enable_mem_efficient": True,
},
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
}
else:
from contextlib import nullcontext
SDP_IS_AVAILABLE = False
sdp_kernel = nullcontext
BACKEND_MAP = {}
print(
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
)
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
except ImportError:
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
from .diffusionmodules.util import checkpoint
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
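# Shape sketch (illustrative): FeedForward maps (..., dim) -> (..., dim_out or dim).
#   ff = FeedForward(dim=320, mult=4, glu=True)   # inner width 1280 with GEGLU gating
#   y = ff(x)                                     # x: (B, N, 320) -> y: (B, N, 320)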
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
backend=None,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.backend = backend
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
h = self.heads
if additional_tokens is not None:
# get the number of masked tokens at the beginning of the output sequence
n_tokens_to_mask = additional_tokens.shape[1]
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if n_times_crossframe_attn_in_self:
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
k = repeat(
k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
)
v = repeat(
v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
## old
"""
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim, v)
"""
## new
with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask
) # scale is dim_head ** -0.5 per default
del q, k, v
out = rearrange(out, "b h n d -> b n (h d)", h=h)
if additional_tokens is not None:
# remove additional token
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
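# Shape sketch (illustrative): queries come from x, keys/values from `context`
# (or from x itself when context is None, i.e. self-attention).
#   attn = CrossAttention(query_dim=320, context_dim=1024, heads=8, dim_head=64)
#   y = attn(x, context=ctx)   # x: (B, N, 320), ctx: (B, M, 1024) -> y: (B, N, 320)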
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
):
super().__init__()
# print(
# f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
# f"{heads} heads with a dimension of {dim_head}."
# )
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.attention_op: Optional[Any] = None
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
if additional_tokens is not None:
# get the number of masked tokens at the beginning of the output sequence
n_tokens_to_mask = additional_tokens.shape[1]
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if n_times_crossframe_attn_in_self:
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
k = repeat(
k[::n_times_crossframe_attn_in_self],
"b ... -> (b n) ...",
n=n_times_crossframe_attn_in_self,
)
v = repeat(
v[::n_times_crossframe_attn_in_self],
"b ... -> (b n) ...",
n=n_times_crossframe_attn_in_self,
)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
# TODO: Use this directly in the attention operation, as a bias
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
if additional_tokens is not None:
# remove additional token
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
attn_mode="softmax",
sdp_backend=None,
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
print(
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
)
attn_mode = "softmax"
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
print(
"We do not support vanilla attention anymore, as it is too expensive. Sorry."
)
if not XFORMERS_IS_AVAILABLE:
assert (
False
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
else:
print("Falling back to xformers efficient attention.")
attn_mode = "softmax-xformers"
attn_cls = self.ATTENTION_MODES[attn_mode]
if version.parse(torch.__version__) >= version.parse("2.0.0"):
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
else:
assert sdp_backend is None
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
backend=sdp_backend,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
backend=sdp_backend,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
# if self.checkpoint:
# print(f"{self.__class__.__name__} is using checkpointing")
def forward(
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
kwargs = {"x": x}
if context is not None:
kwargs.update({"context": context})
if additional_tokens is not None:
kwargs.update({"additional_tokens": additional_tokens})
if n_times_crossframe_attn_in_self:
kwargs.update(
{"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
)
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
x = (
self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
if not self.disable_self_attn
else 0,
)
+ x
)
x = (
self.attn2(
self.norm2(x), context=context, additional_tokens=additional_tokens
)
+ x
)
x = self.ff(self.norm3(x)) + x
return x
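# Descriptive note: the block applies, in order, self-attention (attn1; it becomes
# cross-attention when disable_self_attn is set), cross-attention over `context` (attn2),
# and a gated feed-forward, each with a pre-LayerNorm and a residual connection. Note
# that forward() assembles `kwargs` but the checkpointed call only forwards (x, context),
# so additional_tokens and n_times_crossframe_attn_in_self are not routed through it.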
class BasicTransformerSingleLayerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
attn_mode="softmax",
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
attn_cls = self.ATTENTION_MODES[attn_mode]
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim,
)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context) + x
x = self.ff(self.norm2(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding) and reshape it to (b, t, d).
Then apply the standard transformer blocks.
Finally, reshape back to an image.
NEW: use_linear replaces the 1x1 convs with linear projections for more efficiency.
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
use_linear=False,
attn_type="softmax",
use_checkpoint=True,
# sdp_backend=SDPBackend.FLASH_ATTENTION
sdp_backend=None,
):
super().__init__()
# print(
# f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
# )
from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
context_dim = [context_dim]
if exists(context_dim) and isinstance(context_dim, list):
if depth != len(context_dim):
# print(
# f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
# f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
# )
# depth does not match context dims.
assert all(
map(lambda x: x == context_dim[0], context_dim)
), "need homogenous context_dim to match depth automatically"
context_dim = depth * [context_dim[0]]
elif context_dim is None:
context_dim = [None] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
attn_mode=attn_type,
checkpoint=use_checkpoint,
sdp_backend=sdp_backend,
)
for d in range(depth)
]
)
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
if i > 0 and len(context) == 1:
i = 0 # use same context for each block
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
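# Usage sketch (illustrative): SpatialTransformer preserves the spatial layout,
# (B, C, H, W) in -> (B, C, H, W) out, attending over the flattened H*W token sequence.
#   st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40,
#                           depth=1, context_dim=1024, use_linear=True)
#   y = st(x, context=ctx)   # x: (B, 320, H, W), ctx: (B, M, 1024)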
from typing import Any, Union
import torch
import torch.nn as nn
from einops import rearrange
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import NLayerDiscriminator, weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
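# Example (descriptive): adopt_weight gates a loss term until global_step reaches
# threshold, e.g. adopt_weight(1.0, global_step=100, threshold=500) == 0.0 while
# adopt_weight(1.0, global_step=600, threshold=500) == 1.0. It is used below to delay
# the GAN terms until discriminator_iter_start.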
class LatentLPIPS(nn.Module):
def __init__(
self,
decoder_config,
perceptual_weight=1.0,
latent_weight=1.0,
scale_input_to_tgt_size=False,
scale_tgt_to_input_size=False,
perceptual_weight_on_inputs=0.0,
):
super().__init__()
self.scale_input_to_tgt_size = scale_input_to_tgt_size
self.scale_tgt_to_input_size = scale_tgt_to_input_size
self.init_decoder(decoder_config)
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
self.latent_weight = latent_weight
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
def init_decoder(self, config):
self.decoder = instantiate_from_config(config)
if hasattr(self.decoder, "encoder"):
del self.decoder.encoder
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
log = dict()
loss = (latent_inputs - latent_predictions) ** 2
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
image_reconstructions = None
if self.perceptual_weight > 0.0:
image_reconstructions = self.decoder.decode(latent_predictions)
image_targets = self.decoder.decode(latent_inputs)
perceptual_loss = self.perceptual_loss(
image_targets.contiguous(), image_reconstructions.contiguous()
)
loss = (
self.latent_weight * loss.mean()
+ self.perceptual_weight * perceptual_loss.mean()
)
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
if self.perceptual_weight_on_inputs > 0.0:
image_reconstructions = default(
image_reconstructions, self.decoder.decode(latent_predictions)
)
if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate(
image_inputs,
image_reconstructions.shape[2:],
mode="bicubic",
antialias=True,
)
elif self.scale_tgt_to_input_size:
image_reconstructions = torch.nn.functional.interpolate(
image_reconstructions,
image_inputs.shape[2:],
mode="bicubic",
antialias=True,
)
perceptual_loss2 = self.perceptual_loss(
image_inputs.contiguous(), image_reconstructions.contiguous()
)
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
return loss, log
class GeneralLPIPSWithDiscriminator(nn.Module):
def __init__(
self,
disc_start: int,
logvar_init: float = 0.0,
pixelloss_weight=1.0,
disc_num_layers: int = 3,
disc_in_channels: int = 3,
disc_factor: float = 1.0,
disc_weight: float = 1.0,
perceptual_weight: float = 1.0,
disc_loss: str = "hinge",
scale_input_to_tgt_size: bool = False,
dims: int = 2,
learn_logvar: bool = False,
regularization_weights: Union[None, dict] = None,
):
super().__init__()
self.dims = dims
if self.dims > 2:
print(
f"running with dims={dims}. This means that for perceptual loss calculation, "
f"the LPIPS loss will be applied to each frame independently. "
)
self.scale_input_to_tgt_size = scale_input_to_tgt_size
assert disc_loss in ["hinge", "vanilla"]
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.learn_logvar = learn_logvar
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.regularization_weights = default(regularization_weights, {})
def get_trainable_parameters(self) -> Any:
return self.discriminator.parameters()
def get_trainable_autoencoder_parameters(self) -> Any:
if self.learn_logvar:
yield self.logvar
yield from ()
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
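# Descriptive note: the adaptive weight balances reconstruction and GAN gradients at the
# decoder's last layer,
#     d_weight = ||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4),
# clamped to [0, 1e4], detached, and scaled by discriminator_weight (the same heuristic
# used in VQGAN-style losses).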
def forward(
self,
regularization_log,
inputs,
reconstructions,
optimizer_idx,
global_step,
last_layer=None,
split="train",
weights=None,
):
if self.scale_input_to_tgt_size:
inputs = torch.nn.functional.interpolate(
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
)
if self.dims > 2:
inputs, reconstructions = map(
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
(inputs, reconstructions),
)
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
# now the GAN part
if optimizer_idx == 0:
# generator update
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
else:
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
loss = weighted_nll_loss + d_weight * disc_factor * g_loss
log = dict()
for k in regularization_log:
if k in self.regularization_weights:
loss = loss + self.regularization_weights[k] * regularization_log[k]
log[f"{split}/{k}"] = regularization_log[k].detach().mean()
log.update(
{
"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/logvar".format(split): self.logvar.detach(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
)
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean(),
}
return d_loss, log
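# Descriptive note: this loss implements the two-pass protocol expected by
# AutoencodingEngine.training_step above: optimizer_idx == 0 returns the generator-side
# objective (weighted NLL + adaptively weighted GAN term + any configured regularization
# terms), while optimizer_idx == 1 returns only the hinge/vanilla discriminator loss
# computed on detached inputs and reconstructions.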