Commit d03ea00f authored by suily

Initial commit

*.swp
**/__pycache__/**
**/.ipynb_checkpoints/**
.DS_Store
.idea/*
.vscode/*
llava/
_vis_cached/
_auto_*
ckpt/
log/
tb*/
img*/
local_output*
*.pth
*.pth.tar
*.ckpt
*.log
backup/
checkpoint/
dataset/
result/
MIT License
Copyright (c) 2024 FoundationVision
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# VAR
## Paper
`Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction`
- https://arxiv.org/abs/2404.02905
## Model Architecture
VAR redefines autoregressive learning on images as coarse-to-fine "next-scale prediction" (or "next-resolution prediction"), in contrast to the conventional raster-scan "next-token prediction". A multi-scale vector-quantized variational autoencoder (VQ-VAE) encodes an image into token maps at different resolutions, and a visual autoregressive transformer (VAR transformer) learns the distribution of images. The VQ-VAE first encodes the image into multi-scale token maps; autoregressive generation then starts from a 1x1 token map and progressively enlarges the resolution. At each step, the VAR transformer predicts the next, higher-resolution token map conditioned on all previously generated lower-resolution ones.
<div align=center>
<img src="./doc/VAR.PNG"/>
</div>
## Algorithm
1. Multi-scale vector-quantized variational autoencoder (VQ-VAE): encodes an image into token maps at different resolutions
(a) The image is passed through the variational autoencoder to obtain a feature map f
(b) For each predefined scale, f is interpolated to that scale and quantized against a codebook (typically by nearest-neighbor matching), turning continuous features into discrete tokens
(c) The multi-scale feature encoding is performed in a residual manner (see the sketch after the figure below)
(d) An embedding layer maps the discrete tokens back to continuous embedding vectors
(e) The feature map f is reconstructed from the multi-scale features and passed through the decoder to obtain the reconstructed image; the VQ-VAE is trained with a reconstruction loss and related objectives
2. Visual autoregressive transformer (VAR transformer): learns the image distribution
(a) A GPT-style decoder performs autoregressive learning, predicting higher-resolution token maps from lower-resolution ones
(b) The multi-scale token maps produced by the multi-scale VQ-VAE serve as ground truth to supervise training
(c) All tokens within one scale's token map are generated in parallel, rather than one by one
<div align=center>
<img src="./doc/VAR_detail.PNG"/>
</div>
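To make steps 1(b)-(d) concrete, here is a minimal sketch of multi-scale residual quantization. It roughly mirrors what `VectorQuantizer2.forward` in `models/quant.py` does, but omits the shared `Phi` convolutions, the VQ/commitment losses, and the EMA usage statistics; the function name and the toy shapes are illustrative only.
```python
import torch
import torch.nn.functional as F

def multiscale_residual_quantize(f, codebook, scales=(1, 2, 4, 8, 16)):
    """Quantize a feature map f (B, C, H, W) into one token map per scale, residual-style."""
    B, C, H, W = f.shape
    f_hat = torch.zeros_like(f)   # running reconstruction of f
    residual = f.clone()          # what is still left to explain
    token_maps = []
    for pn in scales:
        # (b) interpolate the residual to the current scale
        r = F.interpolate(residual, size=(pn, pn), mode='area') if pn != H else residual
        r_flat = r.permute(0, 2, 3, 1).reshape(-1, C)
        # nearest-neighbor lookup in the codebook of shape (V, C)
        d = r_flat.pow(2).sum(1, keepdim=True) + codebook.pow(2).sum(1) - 2 * r_flat @ codebook.T
        idx = d.argmin(dim=1).view(B, pn, pn)              # discrete tokens at this scale
        token_maps.append(idx)
        # (d) embed the tokens back to continuous features and upsample to full resolution
        z = codebook[idx].permute(0, 3, 1, 2)
        z_up = F.interpolate(z, size=(H, W), mode='bicubic') if pn != H else z
        # (c) residual update: remove what this scale already explains
        f_hat = f_hat + z_up
        residual = f - f_hat
    return token_maps, f_hat

# toy usage: a random 16x16 feature map and a 4096-entry codebook
f = torch.randn(2, 32, 16, 16)
codebook = torch.randn(4096, 32)
token_maps, f_hat = multiscale_residual_quantize(f, codebook)
print([tuple(t.shape) for t in token_maps])  # [(2, 1, 1), (2, 2, 2), (2, 4, 4), (2, 8, 8), (2, 16, 16)]
```
Each scale only has to encode what the coarser scales missed, so the token maps of all scales together reconstruct f.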
## Code Changes
Note: the repository already contains the modified code; no further changes are needed.
```
# 1. VAR/train.py
is_val_and_also_saving = (ep + 1) % 1 == 0 or (ep + 1) == args.ep # training takes long; save a checkpoint every epoch to guard against interruption (originally (ep + 1) % 10)
vae_ckpt = '/path/your_code_data/VAR/vae_ch160v4096z32.pth' # update the path to the VAE checkpoint
# 2. VAR/models/basic_var.py
fused_mlp_func = None # added below the imports: DCU does not support fused_dense.linear_act_forward
memory_efficient_attention = None # added below the imports: xformers has no DCU-compatible kernels for this repository
```
## Environment Setup
### Docker (Option 1)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
docker run -it --network=host --privileged=true --name=var --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro <imageID> /bin/bash # replace <imageID> with the ID of the image pulled above
cd /your_code_path/VAR
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
pip install -r requirements.txt
wget https://download.sourcefind.cn:65024/directlink/4/flash_attn/DAS1.1/flash_attn-2.0.4+das1.1gitc7a8c18.abi1.dtk2404.torch2.1-cp310-cp310-manylinux_2_31_x86_64.whl
pip install flash_attn-2.0.4+das1.1gitc7a8c18.abi1.dtk2404.torch2.1-cp310-cp310-manylinux_2_31_x86_64.whl
git config --global --add safe.directory /your_code_path/VAR
```
### Dockerfile (Option 2)
```
docker build --no-cache -t var:latest .
docker run -it --network=host --privileged=true --name=var --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro var /bin/bash
cd /path/your_code_data/VAR
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
pip install -r requirements.txt
wget https://download.sourcefind.cn:65024/directlink/4/flash_attn/DAS1.1/flash_attn-2.0.4+das1.1gitc7a8c18.abi1.dtk2404.torch2.1-cp310-cp310-manylinux_2_31_x86_64.whl
pip install flash_attn-2.0.4+das1.1gitc7a8c18.abi1.dtk2404.torch2.1-cp310-cp310-manylinux_2_31_x86_64.whl
git config --global --add safe.directory /your_code_path/VAR
```
### Anaconda (Option 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded from the Guanghe developer community (光合开发者社区): https://developer.hpccube.com/tool/
```
DTK software stack: dtk24.04.1
python: 3.10
pytorch: 2.1.0
flash_attn: 2.0.4
```
`Tip: the versions of the DTK stack, python, pytorch, and the other DCU-related tools above must match each other exactly.`
2. Install the remaining ordinary dependencies as follows:
```
cd /path/your_code_data/VAR
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
pip install -r requirements.txt
git config --global --add safe.directory /your_code_path/VAR
```
## Training Dataset
`imagenet-1k`
Only ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar are needed. They can be downloaded from [scnet](http://113.200.138.88:18080/aidatasets/project-dependency/imagenet-2012) or the [official site](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php); the downloaded archives then need preprocessing. The commands below download the dataset from the official site and preprocess it:
Note: this repository also provides a small dataset (roughly 1/130 of the full dataset) for training and testing, which can be downloaded from [scnet](http://113.200.138.88:18080/aidatasets/project-dependency/var).
```
cd /path/your_code_data/VAR
mkdir dataset
# download the imagenet-1k training and validation sets from the official site
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar -P ./dataset
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar -P ./dataset
# preprocessing: extract the 1000 per-class tars into train/ and val/, creating one folder per class so that PyTorch can read the images and their class labels.
cd dataset
mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train && tar -xvf ILSVRC2012_img_train.tar && rm ILSVRC2012_img_train.tar
find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
cd .. && mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar && rm ILSVRC2012_img_val.tar
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
```
The dataset directory layout is as follows:
```
── dataset
   ├── train
   │   ├── n01440764
   │   │   ├── n01440764_18.JPEG
   │   │   └── ...
   │   ├── n01443537
   │   │   ├── n01443537_2.JPEG
   │   │   └── ...
   │   └── ...
   └── val
       ├── n01440764
       │   ├── ILSVRC2012_val_00000293.JPEG
       │   └── ...
       ├── n01443537
       │   ├── ILSVRC2012_val_00000236.JPEG
       │   └── ...
       └── ...
```
## Training
This repository does not include training code for the VQ-VAE module, so the vae_ch160v4096z32 weights must be downloaded before training. They are available from [scnet](http://113.200.138.88:18080/aimodels/findsource-dependency/var_pytorch) or [huggingface](https://huggingface.co/FoundationVision/var/tree/main); the huggingface download commands are:
```
cd /path/your_code_data/VAR
export HF_DATASETS_CACHE="./checkpoint"
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli download --resume-download FoundationVision/var vae_ch160v4096z32.pth --local-dir checkpoint
```
### Single node, multiple GPUs
```
cd /path/your_code_data/VAR
HIP_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 train.py \
--depth=16 --bs=192 --ep=3 --fp16=1 --alng=1e-3 --wpe=0.1 \
--data_path=/home/VAR/dataset
# torchrun: the launcher for distributed training.
# --nproc_per_node: number of processes per node (i.e. GPUs used on each node).
# --nnodes: total number of nodes participating in training.
# --node_rank: rank of the current node (starting from 0).
# --master_addr: IP address of the master node.
# --master_port: communication port on the master node.
# --depth=16: depth of the model.
# --bs: total batch size.
# --ep: total number of training epochs.
# --fp16=1: enable FP16 training.
# --alng=1e-3: initialization of ada_lin.w[gamma channels].
# --wpe=0.1: final learning rate at the end of training.
# --data_path: path to the dataset.
```
## Inference
Checkpoints can be downloaded from [scnet](http://113.200.138.88:18080/aimodels/findsource-dependency/var_pytorch) or [huggingface](https://huggingface.co/FoundationVision/var/tree/main); the huggingface download commands are:
```
cd /path/your_code_data/VAR
export HF_DATASETS_CACHE="./checkpoint"
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli download --resume-download FoundationVision/var --local-dir checkpoint
```
Inference parameters can be modified in test.py. class_labels determines the number and classes of the generated images: the n-th element is the ImageNet label of the n-th generated image. Run inference with:
```
cd /path/your_code_data/VAR
python test.py
```
Note: the mapping between ImageNet labels and classes is provided in VAR/doc/imagenet_classes.txt and VAR/doc/imagenet_synsets.txt. For example, for label 980, first look up line 981 of imagenet_classes.txt (labels start at 0, so this is the 981st class), which is n09472597; then search imagenet_synsets.txt for n09472597, which corresponds to "volcano", i.e. the generated image is a volcano.
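The manual lookup above can be automated with a small helper. This is only a sketch: it assumes imagenet_classes.txt lists one synset ID per line in label order, that each line of imagenet_synsets.txt begins with a synset ID followed by its description, and that it is run from the VAR/doc directory's parent.
```python
def label_to_class_name(label: int,
                        classes_file: str = 'doc/imagenet_classes.txt',
                        synsets_file: str = 'doc/imagenet_synsets.txt') -> str:
    """Map an ImageNet label (0-999) to a human-readable class name."""
    with open(classes_file) as f:
        synset_ids = [line.strip() for line in f if line.strip()]
    wnid = synset_ids[label]                     # e.g. label 980 -> 'n09472597'
    with open(synsets_file) as f:
        for line in f:
            sid, _, desc = line.strip().partition(' ')
            if sid == wnid:
                return desc                      # e.g. 'volcano'
    raise KeyError(f'{wnid} not found in {synsets_file}')

# print(label_to_class_name(980))  # expected: 'volcano'
```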
## Results
The default inference output of test.py is:
<div align=center>
<img src="./doc/inference.png"/>
</div>
### Accuracy
The default training results are as follows:
| Device | Test configuration | Software stack | Final loss |
| -------------------------------- | ------------------------------------------------ | ---------- | ---------- |
| A800 * 4<br/>(80G, 1410 MHz) | depth=16<br/>ep=3<br/>bs=192<br/>fp16=1 | cuda11.8 | 7.173136 |
| k100ai * 4<br/>(64G, 1400 MHz) | depth=16<br/>ep=3<br/>bs=192<br/>fp16=1 | dtk24.04.1 | 7.151511 |
## Application Scenarios
### Algorithm Category
`Text-to-image generation`
### Key Application Industries
`Furniture, e-commerce, healthcare, broadcast media, education`
## Pretrained Weights
- http://113.200.138.88:18080/aimodels/findsource-dependency/var_pytorch
- https://huggingface.co/FoundationVision/var/tree/main
## Source Repository & Issue Reporting
- https://developer.sourcefind.cn/codes/modelzoo/var_pytorch
## References
- https://github.com/FoundationVision/VAR
# VAR: a new visual generation method elevates GPT-style models beyond diffusion🚀 & Scaling laws observed📈
<div align="center">
[![demo platform](https://img.shields.io/badge/Play%20with%20VAR%21-VAR%20demo%20platform-lightblue)](https://var.vision/demo)&nbsp;
[![arXiv](https://img.shields.io/badge/arXiv%20paper-2404.02905-b31b1b.svg)](https://arxiv.org/abs/2404.02905)&nbsp;
[![huggingface weights](https://img.shields.io/badge/%F0%9F%A4%97%20Weights-FoundationVision/var-yellow)](https://huggingface.co/FoundationVision/var)&nbsp;
[![SOTA](https://img.shields.io/badge/State%20of%20the%20Art-Image%20Generation%20on%20ImageNet%20%28AR%29-32B1B4?logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iNjA2IiBoZWlnaHQ9IjYwNiIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIiB4bWxuczp4bGluaz0iaHR0cDovL3d3dy53My5vcmcvMTk5OS94bGluayIgb3ZlcmZsb3c9ImhpZGRlbiI%2BPGRlZnM%2BPGNsaXBQYXRoIGlkPSJjbGlwMCI%2BPHJlY3QgeD0iLTEiIHk9Ii0xIiB3aWR0aD0iNjA2IiBoZWlnaHQ9IjYwNiIvPjwvY2xpcFBhdGg%2BPC9kZWZzPjxnIGNsaXAtcGF0aD0idXJsKCNjbGlwMCkiIHRyYW5zZm9ybT0idHJhbnNsYXRlKDEgMSkiPjxyZWN0IHg9IjUyOSIgeT0iNjYiIHdpZHRoPSI1NiIgaGVpZ2h0PSI0NzMiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIxOSIgeT0iNjYiIHdpZHRoPSI1NyIgaGVpZ2h0PSI0NzMiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIyNzQiIHk9IjE1MSIgd2lkdGg9IjU3IiBoZWlnaHQ9IjMwMiIgZmlsbD0iIzQ0RjJGNiIvPjxyZWN0IHg9IjEwNCIgeT0iMTUxIiB3aWR0aD0iNTciIGhlaWdodD0iMzAyIiBmaWxsPSIjNDRGMkY2Ii8%2BPHJlY3QgeD0iNDQ0IiB5PSIxNTEiIHdpZHRoPSI1NyIgaGVpZ2h0PSIzMDIiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIzNTkiIHk9IjE3MCIgd2lkdGg9IjU2IiBoZWlnaHQ9IjI2NCIgZmlsbD0iIzQ0RjJGNiIvPjxyZWN0IHg9IjE4OCIgeT0iMTcwIiB3aWR0aD0iNTciIGhlaWdodD0iMjY0IiBmaWxsPSIjNDRGMkY2Ii8%2BPHJlY3QgeD0iNzYiIHk9IjY2IiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI0ODIiIHk9IjY2IiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI3NiIgeT0iNDgyIiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI0ODIiIHk9IjQ4MiIgd2lkdGg9IjQ3IiBoZWlnaHQ9IjU3IiBmaWxsPSIjNDRGMkY2Ii8%2BPC9nPjwvc3ZnPg%3D%3D)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?tag_filter=485&p=visual-autoregressive-modeling-scalable-image)
</div>
<p align="center" style="font-size: larger;">
<a href="https://arxiv.org/abs/2404.02905">Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction</a>
</p>
<p align="center">
<img src="https://github.com/FoundationVision/VAR/assets/39692511/9850df90-20b1-4f29-8592-e3526d16d755" width=95%>
<p>
<br>
## 🕹️ Try and Play with VAR!
We provide a [demo website](https://var.vision/demo) for you to play with VAR models and generate images interactively. Enjoy the fun of visual autoregressive modeling!
We also provide [demo_sample.ipynb](demo_sample.ipynb) for you to see more technical details about VAR.
[//]: # (<p align="center">)
[//]: # (<img src="https://user-images.githubusercontent.com/39692511/226376648-3f28a1a6-275d-4f88-8f3e-cd1219882488.png" width=50%)
[//]: # (<p>)
## What's New?
### 🔥 Introducing VAR: a new paradigm in autoregressive visual generation✨:
Visual Autoregressive Modeling (VAR) redefines the autoregressive learning on images as coarse-to-fine "next-scale prediction" or "next-resolution prediction", diverging from the standard raster-scan "next-token prediction".
<p align="center">
<img src="https://github.com/FoundationVision/VAR/assets/39692511/3e12655c-37dc-4528-b923-ec6c4cfef178" width=93%>
<p>
### 🔥 For the first time, GPT-style autoregressive models surpass diffusion models🚀:
<p align="center">
<img src="https://github.com/FoundationVision/VAR/assets/39692511/cc30b043-fa4e-4d01-a9b1-e50650d5675d" width=55%>
<p>
### 🔥 Discovering power-law Scaling Laws in VAR transformers📈:
<p align="center">
<img src="https://github.com/FoundationVision/VAR/assets/39692511/c35fb56e-896e-4e4b-9fb9-7a1c38513804" width=85%>
<p>
<p align="center">
<img src="https://github.com/FoundationVision/VAR/assets/39692511/91d7b92c-8fc3-44d9-8fb4-73d6cdb8ec1e" width=85%>
<p>
### 🔥 Zero-shot generalizability🛠️:
<p align="center">
<img src="https://github.com/FoundationVision/VAR/assets/39692511/a54a4e52-6793-4130-bae2-9e459a08e96a" width=70%>
<p>
#### For a deep dive into our analyses, discussions, and evaluations, check out our [paper](https://arxiv.org/abs/2404.02905).
## VAR zoo
We provide VAR models for you to play with, which are on <a href='https://huggingface.co/FoundationVision/var'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Huggingface-FoundationVision/var-yellow'></a> or can be downloaded from the following links:
| model | reso. | FID | rel. cost | #params | HF weights🤗 |
|:----------:|:-----:|:--------:|:---------:|:-------:|:------------------------------------------------------------------------------------|
| VAR-d16 | 256 | 3.55 | 0.4 | 310M | [var_d16.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d16.pth) |
| VAR-d20 | 256 | 2.95 | 0.5 | 600M | [var_d20.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d20.pth) |
| VAR-d24 | 256 | 2.33 | 0.6 | 1.0B | [var_d24.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d24.pth) |
| VAR-d30 | 256 | 1.97 | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |
| VAR-d30-re | 256 | **1.80** | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |
You can load these models to generate images via the codes in [demo_sample.ipynb](demo_sample.ipynb). Note: you need to download [vae_ch160v4096z32.pth](https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth) first.
## Installation
1. Install `torch>=2.0.0`.
2. Install other pip packages via `pip3 install -r requirements.txt`.
3. Prepare the [ImageNet](http://image-net.org/) dataset
<details>
<summary> Assume ImageNet is in `/path/to/imagenet`; it should look like this:</summary>
```
/path/to/imagenet/:
train/:
n01440764:
many_images.JPEG ...
n01443537:
many_images.JPEG ...
val/:
n01440764:
ILSVRC2012_val_00000293.JPEG ...
n01443537:
ILSVRC2012_val_00000236.JPEG ...
```
**NOTE: The arg `--data_path=/path/to/imagenet` should be passed to the training script.**
</details>
4. (Optional) install and compile `flash-attn` and `xformers` for faster attention computation. Our code will automatically use them if installed. See [models/basic_var.py#L15-L30](models/basic_var.py#L15-L30).
## Training Scripts
To train VAR-{d16, d20, d24, d30, d36-s} on ImageNet 256x256 or 512x512, you can run the following command:
```shell
# d16, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1
# d20, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=20 --bs=768 --ep=250 --fp16=1 --alng=1e-3 --wpe=0.1
# d24, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=24 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-4 --wpe=0.01
# d30, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=30 --bs=1024 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-5 --wpe=0.01 --twde=0.08
# d36-s, 512x512 (-s means saln=1, shared AdaLN)
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=36 --saln=1 --pn=512 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=5e-6 --wpe=0.01 --twde=0.08
```
A folder named `local_output` will be created to save the checkpoints and logs.
You can monitor the training process by checking the logs in `local_output/log.txt` and `local_output/stdout.txt`, or using `tensorboard --logdir=local_output/`.
If your experiment is interrupted, just rerun the command, and the training will **automatically resume** from the last checkpoint in `local_output/ckpt*.pth` (see [utils/misc.py#L344-L357](utils/misc.py#L344-L357)).
## Sampling & Zero-shot Inference
For FID evaluation, use `var.autoregressive_infer_cfg(..., cfg=1.5, top_p=0.96, top_k=900, more_smooth=False)` to sample 50,000 images (50 per class) and save them as PNG (not JPEG) files in a folder. Pack them into a `.npz` file via `create_npz_from_sample_folder(sample_folder)` in [utils/misc.py#L344](utils/misc.py#L360).
Then use the [OpenAI's FID evaluation toolkit](https://github.com/openai/guided-diffusion/tree/main/evaluations) and reference ground truth npz file of [256x256](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) or [512x512](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) to evaluate FID, IS, precision, and recall.
Note that a relatively small `cfg=1.5` is used to trade off image quality against diversity. You can adjust it to `cfg=5.0`, or sample with `autoregressive_infer_cfg(..., more_smooth=True)` for **better visual quality**.
We'll provide the sampling script later.
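Until then, a rough sketch of such a sampling loop, assembled from the calls shown in [demo_sample.ipynb](demo_sample.ipynb) and `utils/misc.py`, might look like the following; the batch size, output folder, and file naming are arbitrary choices here, not official defaults.
```python
import os
import torch
import torchvision
from models import build_vae_var
from utils.misc import create_npz_from_sample_folder

MODEL_DEPTH, device = 16, 'cuda'
vae, var = build_vae_var(
    V=4096, Cvae=32, ch=160, share_quant_resi=4,           # VQVAE hyperparameters from demo_sample.ipynb
    device=device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
)
vae.load_state_dict(torch.load('vae_ch160v4096z32.pth', map_location='cpu'), strict=True)
var.load_state_dict(torch.load(f'var_d{MODEL_DEPTH}.pth', map_location='cpu'), strict=True)
vae.eval(); var.eval()

sample_folder = 'samples_d16_cfg1.5'
os.makedirs(sample_folder, exist_ok=True)
per_class, bs = 50, 25   # 1000 classes x 50 images each = 50,000 samples

with torch.inference_mode():
    for cls in range(1000):
        for start in range(0, per_class, bs):
            label_B = torch.full((bs,), cls, dtype=torch.long, device=device)
            imgs = var.autoregressive_infer_cfg(
                B=bs, label_B=label_B, cfg=1.5, top_p=0.96, top_k=900,
                g_seed=cls * per_class + start, more_smooth=False,
            )  # (B, 3, H, W), values in [0, 1] as in demo_sample.ipynb
            for i, img in enumerate(imgs):
                torchvision.utils.save_image(
                    img, os.path.join(sample_folder, f'{cls:04d}_{start + i:02d}.png'))

create_npz_from_sample_folder(sample_folder)  # pack the PNGs into a .npz for the FID toolkit
```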
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## Citation
If our work assists your research, feel free to give us a star ⭐ or cite us using:
```
@Article{VAR,
title={Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction},
author={Keyu Tian and Yi Jiang and Zehuan Yuan and Bingyue Peng and Liwei Wang},
year={2024},
eprint={2404.02905},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
{
"cells": [
{
"cell_type": "markdown",
"source": [
"### 🚀 For an interactive experience, head over to our [demo platform](https://var.vision/demo) and dive right in! 🌟"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"################## 1. Download checkpoints and build models\n",
"import os\n",
"import os.path as osp\n",
"import torch, torchvision\n",
"import random\n",
"import numpy as np\n",
"import PIL.Image as PImage, PIL.ImageDraw as PImageDraw\n",
"setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed\n",
"setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed\n",
"from models import VQVAE, build_vae_var\n",
"\n",
"MODEL_DEPTH = 16 # TODO: =====> please specify MODEL_DEPTH <=====\n",
"assert MODEL_DEPTH in {16, 20, 24, 30}\n",
"\n",
"\n",
"# download checkpoint\n",
"hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'\n",
"vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'\n",
"if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')\n",
"if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')\n",
"\n",
"# build vae, var\n",
"patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"if 'vae' not in globals() or 'var' not in globals():\n",
" vae, var = build_vae_var(\n",
" V=4096, Cvae=32, ch=160, share_quant_resi=4, # hard-coded VQVAE hyperparameters\n",
" device=device, patch_nums=patch_nums,\n",
" num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,\n",
" )\n",
"\n",
"# load checkpoints\n",
"vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)\n",
"var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)\n",
"vae.eval(), var.eval()\n",
"for p in vae.parameters(): p.requires_grad_(False)\n",
"for p in var.parameters(): p.requires_grad_(False)\n",
"print(f'prepare finished.')"
],
"metadata": {
"collapsed": false,
"is_executing": true
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"############################# 2. Sample with classifier-free guidance\n",
"\n",
"# set args\n",
"seed = 0 #@param {type:\"number\"}\n",
"torch.manual_seed(seed)\n",
"num_sampling_steps = 250 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
"cfg = 4 #@param {type:\"slider\", min:1, max:10, step:0.1}\n",
"class_labels = (980, 980, 437, 437, 22, 22, 562, 562) #@param {type:\"raw\"}\n",
"more_smooth = False # True for more smooth output\n",
"\n",
"# seed\n",
"torch.manual_seed(seed)\n",
"random.seed(seed)\n",
"np.random.seed(seed)\n",
"torch.backends.cudnn.deterministic = True\n",
"torch.backends.cudnn.benchmark = False\n",
"\n",
"# run faster\n",
"tf32 = True\n",
"torch.backends.cudnn.allow_tf32 = bool(tf32)\n",
"torch.backends.cuda.matmul.allow_tf32 = bool(tf32)\n",
"torch.set_float32_matmul_precision('high' if tf32 else 'highest')\n",
"\n",
"# sample\n",
"B = len(class_labels)\n",
"label_B: torch.LongTensor = torch.tensor(class_labels, device=device)\n",
"with torch.inference_mode():\n",
" with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True): # using bfloat16 can be faster\n",
" recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)\n",
"\n",
"chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)\n",
"chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()\n",
"chw = PImage.fromarray(chw.astype(np.uint8))\n",
"chw.show()\n"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
import datetime
import functools
import os
import sys
from typing import List
from typing import Union
import torch
import torch.distributed as tdist
import torch.multiprocessing as mp
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
__initialized = False
def initialized():
return __initialized
def initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout=30):
global __device
if not torch.cuda.is_available():
print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
return
elif 'RANK' not in os.environ:
torch.cuda.set_device(gpu_id_if_not_distibuted)
__device = torch.empty(1).cuda().device
print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
return
# then 'RANK' must exist
global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
local_rank = global_rank % num_gpus
torch.cuda.set_device(local_rank)
# ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
if mp.get_start_method(allow_none=True) is None:
method = 'fork' if fork else 'spawn'
print(f'[dist initialize] mp method={method}')
mp.set_start_method(method)
tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout*60))
global __rank, __local_rank, __world_size, __initialized
__local_rank = local_rank
__rank, __world_size = tdist.get_rank(), tdist.get_world_size()
__device = torch.empty(1).cuda().device
__initialized = True
assert tdist.is_initialized(), 'torch.distributed is not initialized!'
print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
def get_rank():
return __rank
def get_local_rank():
return __local_rank
def get_world_size():
return __world_size
def get_device():
return __device
def set_gpu_id(gpu_id: int):
if gpu_id is None: return
global __device
if isinstance(gpu_id, (str, int)):
torch.cuda.set_device(int(gpu_id))
__device = torch.empty(1).cuda().device
else:
raise NotImplementedError
def is_master():
return __rank == 0
def is_local_master():
return __local_rank == 0
def new_group(ranks: List[int]):
if __initialized:
return tdist.new_group(ranks=ranks)
return None
def barrier():
if __initialized:
tdist.barrier()
def allreduce(t: torch.Tensor, async_op=False):
if __initialized:
if not t.is_cuda:
cu = t.detach().cuda()
ret = tdist.all_reduce(cu, async_op=async_op)
t.copy_(cu.cpu())
else:
ret = tdist.all_reduce(t, async_op=async_op)
return ret
return None
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
if __initialized:
if not t.is_cuda:
t = t.cuda()
ls = [torch.empty_like(t) for _ in range(__world_size)]
tdist.all_gather(ls, t)
else:
ls = [t]
if cat:
ls = torch.cat(ls, dim=0)
return ls
def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
if __initialized:
if not t.is_cuda:
t = t.cuda()
t_size = torch.tensor(t.size(), device=t.device)
ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
tdist.all_gather(ls_size, t_size)
max_B = max(size[0].item() for size in ls_size)
pad = max_B - t_size[0].item()
if pad:
pad_size = (pad, *t.size()[1:])
t = torch.cat((t, t.new_empty(pad_size)), dim=0)
ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
tdist.all_gather(ls_padded, t)
ls = []
for t, size in zip(ls_padded, ls_size):
ls.append(t[:size[0].item()])
else:
ls = [t]
if cat:
ls = torch.cat(ls, dim=0)
return ls
def broadcast(t: torch.Tensor, src_rank) -> None:
if __initialized:
if not t.is_cuda:
cu = t.detach().cuda()
tdist.broadcast(cu, src=src_rank)
t.copy_(cu.cpu())
else:
tdist.broadcast(t, src=src_rank)
def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
if not initialized():
return torch.tensor([val]) if fmt is None else [fmt % val]
ts = torch.zeros(__world_size)
ts[__rank] = val
allreduce(ts)
if fmt is None:
return ts
return [fmt % v for v in ts.cpu().numpy().tolist()]
def master_only(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
force = kwargs.pop('force', False)
if force or is_master():
ret = func(*args, **kwargs)
else:
ret = None
barrier()
return ret
return wrapper
def local_master_only(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
force = kwargs.pop('force', False)
if force or is_local_master():
ret = func(*args, **kwargs)
else:
ret = None
barrier()
return ret
return wrapper
def for_visualize(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_master():
# with torch.no_grad():
ret = func(*args, **kwargs)
else:
ret = None
return ret
return wrapper
def finalize():
if __initialized:
tdist.destroy_process_group()
n01440764
n01443537
n01484850
n01491361
n01494475
n01496331
n01498041
n01514668
n01514859
n01518878
n01530575
n01531178
n01532829
n01534433
n01537544
n01558993
n01560419
n01580077
n01582220
n01592084
n01601694
n01608432
n01614925
n01616318
n01622779
n01629819
n01630670
n01631663
n01632458
n01632777
n01641577
n01644373
n01644900
n01664065
n01665541
n01667114
n01667778
n01669191
n01675722
n01677366
n01682714
n01685808
n01687978
n01688243
n01689811
n01692333
n01693334
n01694178
n01695060
n01697457
n01698640
n01704323
n01728572
n01728920
n01729322
n01729977
n01734418
n01735189
n01737021
n01739381
n01740131
n01742172
n01744401
n01748264
n01749939
n01751748
n01753488
n01755581
n01756291
n01768244
n01770081
n01770393
n01773157
n01773549
n01773797
n01774384
n01774750
n01775062
n01776313
n01784675
n01795545
n01796340
n01797886
n01798484
n01806143
n01806567
n01807496
n01817953
n01818515
n01819313
n01820546
n01824575
n01828970
n01829413
n01833805
n01843065
n01843383
n01847000
n01855032
n01855672
n01860187
n01871265
n01872401
n01873310
n01877812
n01882714
n01883070
n01910747
n01914609
n01917289
n01924916
n01930112
n01943899
n01944390
n01945685
n01950731
n01955084
n01968897
n01978287
n01978455
n01980166
n01981276
n01983481
n01984695
n01985128
n01986214
n01990800
n02002556
n02002724
n02006656
n02007558
n02009229
n02009912
n02011460
n02012849
n02013706
n02017213
n02018207
n02018795
n02025239
n02027492
n02028035
n02033041
n02037110
n02051845
n02056570
n02058221
n02066245
n02071294
n02074367
n02077923
n02085620
n02085782
n02085936
n02086079
n02086240
n02086646
n02086910
n02087046
n02087394
n02088094
n02088238
n02088364
n02088466
n02088632
n02089078
n02089867
n02089973
n02090379
n02090622
n02090721
n02091032
n02091134
n02091244
n02091467
n02091635
n02091831
n02092002
n02092339
n02093256
n02093428
n02093647
n02093754
n02093859
n02093991
n02094114
n02094258
n02094433
n02095314
n02095570
n02095889
n02096051
n02096177
n02096294
n02096437
n02096585
n02097047
n02097130
n02097209
n02097298
n02097474
n02097658
n02098105
n02098286
n02098413
n02099267
n02099429
n02099601
n02099712
n02099849
n02100236
n02100583
n02100735
n02100877
n02101006
n02101388
n02101556
n02102040
n02102177
n02102318
n02102480
n02102973
n02104029
n02104365
n02105056
n02105162
n02105251
n02105412
n02105505
n02105641
n02105855
n02106030
n02106166
n02106382
n02106550
n02106662
n02107142
n02107312
n02107574
n02107683
n02107908
n02108000
n02108089
n02108422
n02108551
n02108915
n02109047
n02109525
n02109961
n02110063
n02110185
n02110341
n02110627
n02110806
n02110958
n02111129
n02111277
n02111500
n02111889
n02112018
n02112137
n02112350
n02112706
n02113023
n02113186
n02113624
n02113712
n02113799
n02113978
n02114367
n02114548
n02114712
n02114855
n02115641
n02115913
n02116738
n02117135
n02119022
n02119789
n02120079
n02120505
n02123045
n02123159
n02123394
n02123597
n02124075
n02125311
n02127052
n02128385
n02128757
n02128925
n02129165
n02129604
n02130308
n02132136
n02133161
n02134084
n02134418
n02137549
n02138441
n02165105
n02165456
n02167151
n02168699
n02169497
n02172182
n02174001
n02177972
n02190166
n02206856
n02219486
n02226429
n02229544
n02231487
n02233338
n02236044
n02256656
n02259212
n02264363
n02268443
n02268853
n02276258
n02277742
n02279972
n02280649
n02281406
n02281787
n02317335
n02319095
n02321529
n02325366
n02326432
n02328150
n02342885
n02346627
n02356798
n02361337
n02363005
n02364673
n02389026
n02391049
n02395406
n02396427
n02397096
n02398521
n02403003
n02408429
n02410509
n02412080
n02415577
n02417914
n02422106
n02422699
n02423022
n02437312
n02437616
n02441942
n02442845
n02443114
n02443484
n02444819
n02445715
n02447366
n02454379
n02457408
n02480495
n02480855
n02481823
n02483362
n02483708
n02484975
n02486261
n02486410
n02487347
n02488291
n02488702
n02489166
n02490219
n02492035
n02492660
n02493509
n02493793
n02494079
n02497673
n02500267
n02504013
n02504458
n02509815
n02510455
n02514041
n02526121
n02536864
n02606052
n02607072
n02640242
n02641379
n02643566
n02655020
n02666196
n02667093
n02669723
n02672831
n02676566
n02687172
n02690373
n02692877
n02699494
n02701002
n02704792
n02708093
n02727426
n02730930
n02747177
n02749479
n02769748
n02776631
n02777292
n02782093
n02783161
n02786058
n02787622
n02788148
n02790996
n02791124
n02791270
n02793495
n02794156
n02795169
n02797295
n02799071
n02802426
n02804414
n02804610
n02807133
n02808304
n02808440
n02814533
n02814860
n02815834
n02817516
n02823428
n02823750
n02825657
n02834397
n02835271
n02837789
n02840245
n02841315
n02843684
n02859443
n02860847
n02865351
n02869837
n02870880
n02871525
n02877765
n02879718
n02883205
n02892201
n02892767
n02894605
n02895154
n02906734
n02909870
n02910353
n02916936
n02917067
n02927161
n02930766
n02939185
n02948072
n02950826
n02951358
n02951585
n02963159
n02965783
n02966193
n02966687
n02971356
n02974003
n02977058
n02978881
n02979186
n02980441
n02981792
n02988304
n02992211
n02992529
n02999410
n03000134
n03000247
n03000684
n03014705
n03016953
n03017168
n03018349
n03026506
n03028079
n03032252
n03041632
n03042490
n03045698
n03047690
n03062245
n03063599
n03063689
n03065424
n03075370
n03085013
n03089624
n03095699
n03100240
n03109150
n03110669
n03124043
n03124170
n03125729
n03126707
n03127747
n03127925
n03131574
n03133878
n03134739
n03141823
n03146219
n03160309
n03179701
n03180011
n03187595
n03188531
n03196217
n03197337
n03201208
n03207743
n03207941
n03208938
n03216828
n03218198
n03220513
n03223299
n03240683
n03249569
n03250847
n03255030
n03259280
n03271574
n03272010
n03272562
n03290653
n03291819
n03297495
n03314780
n03325584
n03337140
n03344393
n03345487
n03347037
n03355925
n03372029
n03376595
n03379051
n03384352
n03388043
n03388183
n03388549
n03393912
n03394916
n03400231
n03404251
n03417042
n03424325
n03425413
n03443371
n03444034
n03445777
n03445924
n03447447
n03447721
n03450230
n03452741
n03457902
n03459775
n03461385
n03467068
n03476684
n03476991
n03478589
n03481172
n03482405
n03483316
n03485407
n03485794
n03492542
n03494278
n03495258
n03496892
n03498962
n03527444
n03529860
n03530642
n03532672
n03534580
n03535780
n03538406
n03544143
n03584254
n03584829
n03590841
n03594734
n03594945
n03595614
n03598930
n03599486
n03602883
n03617480
n03623198
n03627232
n03630383
n03633091
n03637318
n03642806
n03649909
n03657121
n03658185
n03661043
n03662601
n03666591
n03670208
n03673027
n03676483
n03680355
n03690938
n03691459
n03692522
n03697007
n03706229
n03709823
n03710193
n03710637
n03710721
n03717622
n03720891
n03721384
n03724870
n03729826
n03733131
n03733281
n03733805
n03742115
n03743016
n03759954
n03761084
n03763968
n03764736
n03769881
n03770439
n03770679
n03773504
n03775071
n03775546
n03776460
n03777568
n03777754
n03781244
n03782006
n03785016
n03786901
n03787032
n03788195
n03788365
n03791053
n03792782
n03792972
n03793489
n03794056
n03796401
n03803284
n03804744
n03814639
n03814906
n03825788
n03832673
n03837869
n03838899
n03840681
n03841143
n03843555
n03854065
n03857828
n03866082
n03868242
n03868863
n03871628
n03873416
n03874293
n03874599
n03876231
n03877472
n03877845
n03884397
n03887697
n03888257
n03888605
n03891251
n03891332
n03895866
n03899768
n03902125
n03903868
n03908618
n03908714
n03916031
n03920288
n03924679
n03929660
n03929855
n03930313
n03930630
n03933933
n03935335
n03937543
n03938244
n03942813
n03944341
n03947888
n03950228
n03954731
n03956157
n03958227
n03961711
n03967562
n03970156
n03976467
n03976657
n03977966
n03980874
n03982430
n03983396
n03991062
n03992509
n03995372
n03998194
n04004767
n04005630
n04008634
n04009552
n04019541
n04023962
n04026417
n04033901
n04033995
n04037443
n04039381
n04040759
n04041544
n04044716
n04049303
n04065272
n04067472
n04069434
n04070727
n04074963
n04081281
n04086273
n04090263
n04099969
n04111531
n04116512
n04118538
n04118776
n04120489
n04125021
n04127249
n04131690
n04133789
n04136333
n04141076
n04141327
n04141975
n04146614
n04147183
n04149813
n04152593
n04153751
n04154565
n04162706
n04179913
n04192698
n04200800
n04201297
n04204238
n04204347
n04208210
n04209133
n04209239
n04228054
n04229816
n04235860
n04238763
n04239074
n04243546
n04251144
n04252077
n04252225
n04254120
n04254680
n04254777
n04258138
n04259630
n04263257
n04264628
n04265275
n04266014
n04270147
n04273569
n04275548
n04277352
n04285008
n04286575
n04296562
n04310018
n04311004
n04311174
n04317175
n04325704
n04326547
n04328186
n04330267
n04332243
n04335435
n04336792
n04344873
n04346328
n04347754
n04350905
n04355338
n04355933
n04356056
n04357314
n04366367
n04367480
n04370456
n04371430
n04371774
n04372370
n04376876
n04380533
n04389033
n04392985
n04398044
n04399382
n04404412
n04409515
n04417672
n04418357
n04423845
n04428191
n04429376
n04435653
n04442312
n04443257
n04447861
n04456115
n04458633
n04461696
n04462240
n04465501
n04467665
n04476259
n04479046
n04482393
n04483307
n04485082
n04486054
n04487081
n04487394
n04493381
n04501370
n04505470
n04507155
n04509417
n04515003
n04517823
n04522168
n04523525
n04525038
n04525305
n04532106
n04532670
n04536866
n04540053
n04542943
n04548280
n04548362
n04550184
n04552348
n04553703
n04554684
n04557648
n04560804
n04562935
n04579145
n04579432
n04584207
n04589890
n04590129
n04591157
n04591713
n04592741
n04596742
n04597913
n04599235
n04604644
n04606251
n04612504
n04613696
n06359193
n06596364
n06785654
n06794110
n06874185
n07248320
n07565083
n07579787
n07583066
n07584110
n07590611
n07613480
n07614500
n07615774
n07684084
n07693725
n07695742
n07697313
n07697537
n07711569
n07714571
n07714990
n07715103
n07716358
n07716906
n07717410
n07717556
n07718472
n07718747
n07720875
n07730033
n07734744
n07742313
n07745940
n07747607
n07749582
n07753113
n07753275
n07753592
n07754684
n07760859
n07768694
n07802026
n07831146
n07836838
n07860988
n07871810
n07873807
n07875152
n07880968
n07892512
n07920052
n07930864
n07932039
n09193705
n09229709
n09246464
n09256479
n09288635
n09332890
n09399592
n09421951
n09428293
n09468604
n09472597
n09835506
n10148035
n10565667
n11879895
n11939491
n12057211
n12144580
n12267677
n12620546
n12768682
n12985857
n12998815
n13037406
n13040303
n13044778
n13052670
n13054560
n13133613
n15075141
# Unique model identifier
modelCode = 1091
# Model name
modelName= var_pytorch
# Model description
modelDescription=VAR is a GPT-style model that redefines autoregressive learning on images as coarse-to-fine "next-scale prediction" or "next-resolution prediction", in contrast to the conventional raster-scan "next-token prediction".
# Application scenarios
appScenario=inference,training,text-to-image,furniture,e-commerce,healthcare,broadcast media,education
# Framework type
frameType=pytorch
from typing import Tuple
import torch.nn as nn
from .quant import VectorQuantizer2
from .var import VAR
from .vqvae import VQVAE
def build_vae_var(
# Shared args
device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
# VQVAE args
V=4096, Cvae=32, ch=160, share_quant_resi=4,
# VAR args
num_classes=1000, depth=16, shared_aln=False, attn_l2_norm=True,
flash_if_available=True, fused_if_available=True,
init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1, # init_std < 0: automated
) -> Tuple[VQVAE, VAR]:
heads = depth
width = depth * 64
dpr = 0.1 * depth/24
# disable built-in initialization for speed
for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d):
setattr(clz, 'reset_parameters', lambda self: None)
# build models
vae_local = VQVAE(vocab_size=V, z_channels=Cvae, ch=ch, test_mode=True, share_quant_resi=share_quant_resi, v_patch_nums=patch_nums).to(device)
var_wo_ddp = VAR(
vae_local=vae_local,
num_classes=num_classes, depth=depth, embed_dim=width, num_heads=heads, drop_rate=0., attn_drop_rate=0., drop_path_rate=dpr,
norm_eps=1e-6, shared_aln=shared_aln, cond_drop_rate=0.1,
attn_l2_norm=attn_l2_norm,
patch_nums=patch_nums,
flash_if_available=flash_if_available, fused_if_available=fused_if_available,
).to(device)
var_wo_ddp.init_weights(init_adaln=init_adaln, init_adaln_gamma=init_adaln_gamma, init_head=init_head, init_std=init_std)
return vae_local, var_wo_ddp
import torch
import torch.nn as nn
import torch.nn.functional as F
# this file only provides the 2 modules used in VQVAE
__all__ = ['Encoder', 'Decoder',]
"""
References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
"""
# swish
def nonlinearity(x):
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample2x(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
class Downsample2x(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0))
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, dropout): # conv_shortcut=False, # conv_shortcut: always False in VAE
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.nin_shortcut = nn.Identity()
def forward(self, x):
h = self.conv1(F.silu(self.norm1(x), inplace=True))
h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True)))
return self.nin_shortcut(x) + h
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.C = in_channels
self.norm = Normalize(in_channels)
self.qkv = torch.nn.Conv2d(in_channels, 3*in_channels, kernel_size=1, stride=1, padding=0)
self.w_ratio = int(in_channels) ** (-0.5)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
qkv = self.qkv(self.norm(x))
B, _, H, W = qkv.shape # should be B,3C,H,W
C = self.C
q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)
# compute attention
q = q.view(B, C, H * W).contiguous()
q = q.permute(0, 2, 1).contiguous() # B,HW,C
k = k.view(B, C, H * W).contiguous() # B,C,HW
w = torch.bmm(q, k).mul_(self.w_ratio) # B,HW,HW w[B,i,j]=sum_c q[B,i,C]k[B,C,j]
w = F.softmax(w, dim=2)
# attend to values
v = v.view(B, C, H * W).contiguous()
w = w.permute(0, 2, 1).contiguous() # B,HW,HW (first HW of k, second of q)
h = torch.bmm(v, w) # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]
h = h.view(B, C, H, W).contiguous()
return x + self.proj_out(h)
def make_attn(in_channels, using_sa=True):
return AttnBlock(in_channels) if using_sa else nn.Identity()
class Encoder(nn.Module):
def __init__(
self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
dropout=0.0, in_channels=3,
z_channels, double_z=False, using_sa=True, using_mid_sa=True,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.downsample_ratio = 2 ** (self.num_resolutions - 1)
self.num_res_blocks = num_res_blocks
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
block_in = block_out
if i_level == self.num_resolutions - 1 and using_sa:
attn.append(make_attn(block_in, using_sa=True))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample2x(block_in)
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, (2 * z_channels if double_z else z_channels), kernel_size=3, stride=1, padding=1)
def forward(self, x):
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))
# end
h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
return h
class Decoder(nn.Module):
def __init__(
self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
dropout=0.0, in_channels=3, # in_channels: raw img channels
z_channels, using_sa=True, using_mid_sa=True,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_channels
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
block_in = block_out
if i_level == self.num_resolutions-1 and using_sa:
attn.append(make_attn(block_in, using_sa=True))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample2x(block_in)
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, z):
# z to block_in
# middle
h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
return h
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.helpers import DropPath, drop_path
# this file only provides the 3 blocks used in VAR transformer
__all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead']
# automatically import fused operators
dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
from flash_attn.ops.fused_dense import fused_mlp_func
except ImportError: pass
# automatically import faster attention implementations
try: from xformers.ops import memory_efficient_attention
except ImportError: pass
try: from flash_attn import flash_attn_func # qkv: BLHc, ret: BLHcq
except ImportError: pass
try: from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc
except ImportError:
def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):
attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL
if attn_mask is not None: attn.add_(attn_mask)
return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value
fused_mlp_func = None # TODO: DCU does not support fused_dense.linear_act_forward
memory_efficient_attention = None # TODO: xformers has no DCU-compatible kernels
class FFN(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True):
super().__init__()
self.fused_mlp_func = fused_mlp_func if fused_if_available else None
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = nn.GELU(approximate='tanh')
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()
def forward(self, x):
if self.fused_mlp_func is not None:
return self.drop(self.fused_mlp_func(
x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias,
activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0,
heuristic=0, process_group=None,
))
else:
return self.drop(self.fc2( self.act(self.fc1(x)) ))
def extra_repr(self) -> str:
return f'fused_mlp_func={self.fused_mlp_func is not None}'
class SelfAttention(nn.Module):
def __init__(
self, block_idx, embed_dim=768, num_heads=12,
attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True,
):
super().__init__()
assert embed_dim % num_heads == 0
self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads # =64
self.attn_l2_norm = attn_l2_norm
if self.attn_l2_norm:
self.scale = 1
self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
self.max_scale_mul = torch.log(torch.tensor(100)).item()
else:
self.scale = 0.25 / math.sqrt(self.head_dim)
self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
self.proj = nn.Linear(embed_dim, embed_dim)
self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
self.attn_drop: float = attn_drop
self.using_flash = flash_if_available and flash_attn_func is not None
self.using_xform = flash_if_available and memory_efficient_attention is not None
# only used during inference
self.caching, self.cached_k, self.cached_v = False, None, None
def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None
# NOTE: attn_bias is None during inference because kv cache is enabled
def forward(self, x, attn_bias):
B, L, C = x.shape
qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim)
main_type = qkv.dtype
# qkv: BL3Hc
using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1 # q or k or v: BLHc
else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2 # q or k or v: BHLc
if self.attn_l2_norm:
scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2) # 1H11 to 11H1
q = F.normalize(q, dim=-1).mul(scale_mul)
k = F.normalize(k, dim=-1)
if self.caching:
if self.cached_k is None: self.cached_k = k; self.cached_v = v
else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)
dropout_p = self.attn_drop if self.training else 0.0
if using_flash:
oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)
elif self.using_xform:
# MemoryEfficientAttentionCkOp
oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C)
else:
oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C)
return self.proj_drop(self.proj(oup))
# attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb()) # BHLc @ BHcL => BHLL
# attn = self.attn_drop(attn.softmax(dim=-1))
# oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1) # BHLL @ BHLc = BHLc => BLHc => BLC
def extra_repr(self) -> str:
return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}'
class AdaLNSelfAttn(nn.Module):
def __init__(
self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer,
num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False,
flash_if_available=False, fused_if_available=True,
):
super(AdaLNSelfAttn, self).__init__()
self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
self.C, self.D = embed_dim, cond_dim
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available)
self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available)
self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
self.shared_aln = shared_aln
if self.shared_aln:
self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
else:
lin = nn.Linear(cond_dim, 6*embed_dim)
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)
self.fused_add_norm_fn = None
# NOTE: attn_bias is None during inference because kv cache is enabled
def forward(self, x, cond_BD, attn_bias): # C: embed_dim, D: cond_dim
if self.shared_aln:
gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
else:
gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1))
x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed when FusedMLP is used
return x
def extra_repr(self) -> str:
return f'shared_aln={self.shared_aln}'
class AdaLNBeforeHead(nn.Module):
def __init__(self, C, D, norm_layer): # C: embed_dim, D: cond_dim
super().__init__()
self.C, self.D = C, D
self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C))
def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
import torch
from torch import nn as nn
from torch.nn import functional as F
def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l)
B, l, V = logits_BlV.shape
if top_k > 0:
idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
if top_p > 0:
sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
sorted_idx_to_remove[..., -1:] = False
logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf)
# sample (have to squeeze cuz torch.multinomial can only be used for 2D tensor)
replacement = num_samples >= 0
num_samples = abs(num_samples)
return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, rng: torch.Generator = None) -> torch.Tensor:
if rng is None:
return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)
gumbels = (-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_(generator=rng).log())
gumbels = (logits + gumbels) / tau
y_soft = gumbels.softmax(dim)
if hard:
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
else:
ret = y_soft
return ret
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): # taken from timm
if drop_prob == 0. or not training: return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class DropPath(nn.Module): # taken from timm
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
        return f'drop_prob={self.drop_prob:g}, scale_by_keep={self.scale_by_keep}'
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch import distributed as tdist, nn as nn
from torch.nn import functional as F
import dist
# this file only provides the VectorQuantizer2 used in VQVAE
__all__ = ['VectorQuantizer2',]
class VectorQuantizer2(nn.Module):
    # VQGAN originally uses beta=1.0 and never tried 0.25; Stable Diffusion appears to use 0.25
def __init__(
self, vocab_size, Cvae, using_znorm, beta: float = 0.25,
default_qresi_counts=0, v_patch_nums=None, quant_resi=0.5, share_quant_resi=4, # share_quant_resi: args.qsr
):
super().__init__()
self.vocab_size: int = vocab_size
self.Cvae: int = Cvae
self.using_znorm: bool = using_znorm
self.v_patch_nums: Tuple[int] = v_patch_nums
self.quant_resi_ratio = quant_resi
if share_quant_resi == 0: # non-shared: \phi_{1 to K} for K scales
self.quant_resi = PhiNonShared([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(default_qresi_counts or len(self.v_patch_nums))])
elif share_quant_resi == 1: # fully shared: only a single \phi for K scales
self.quant_resi = PhiShared(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
else: # partially shared: \phi_{1 to share_quant_resi} for K scales
self.quant_resi = PhiPartiallyShared(nn.ModuleList([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(share_quant_resi)]))
self.register_buffer('ema_vocab_hit_SV', torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0))
self.record_hit = 0
self.beta: float = beta
self.embedding = nn.Embedding(self.vocab_size, self.Cvae)
# only used for progressive training of VAR (not supported yet, will be tested and supported in the future)
self.prog_si = -1 # progressive training: not supported yet, prog_si always -1
def eini(self, eini):
if eini > 0: nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
elif eini < 0: self.embedding.weight.data.uniform_(-abs(eini) / self.vocab_size, abs(eini) / self.vocab_size)
def extra_repr(self) -> str:
return f'{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}'
# ===================== `forward` is only used in VAE training =====================
def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, List[float], torch.Tensor]:
dtype = f_BChw.dtype
if dtype != torch.float32: f_BChw = f_BChw.float()
B, C, H, W = f_BChw.shape
f_no_grad = f_BChw.detach()
f_rest = f_no_grad.clone()
f_hat = torch.zeros_like(f_rest)
with torch.cuda.amp.autocast(enabled=False):
            mean_vq_loss: torch.Tensor = 0.0  # becomes a tensor after the first accumulation in the loop below
vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=f_BChw.device)
SN = len(self.v_patch_nums)
for si, pn in enumerate(self.v_patch_nums): # from small to large
# find the nearest embedding
if self.using_znorm:
rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
rest_NC = F.normalize(rest_NC, dim=-1)
idx_N = torch.argmax(rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
else:
rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
d_no_grad = torch.sum(rest_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)
d_no_grad.addmm_(rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
idx_N = torch.argmin(d_no_grad, dim=1)
hit_V = idx_N.bincount(minlength=self.vocab_size).float()
if self.training:
if dist.initialized(): handler = tdist.all_reduce(hit_V, async_op=True)
# calc loss
idx_Bhw = idx_N.view(B, pn, pn)
h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
h_BChw = self.quant_resi[si/(SN-1)](h_BChw)
f_hat = f_hat + h_BChw
f_rest -= h_BChw
if self.training and dist.initialized():
handler.wait()
if self.record_hit == 0: self.ema_vocab_hit_SV[si].copy_(hit_V)
elif self.record_hit < 100: self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
else: self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
self.record_hit += 1
vocab_hit_V.add_(hit_V)
mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)
mean_vq_loss *= 1. / SN
f_hat = (f_hat.data - f_no_grad).add_(f_BChw)
margin = tdist.get_world_size() * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08
# margin = pn*pn / 100
if ret_usages: usages = [(self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 for si, pn in enumerate(self.v_patch_nums)]
else: usages = None
return f_hat, usages, mean_vq_loss
# ===================== `forward` is only used in VAE training =====================
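The loop in `forward` above implements scale-wise residual quantization: each scale quantizes only what the previous scales have not yet explained. Below is a simplified, self-contained sketch of that idea with a toy codebook; the phi convolution, the z-normalization branch, the losses, and the EMA bookkeeping are deliberately omitted.
```
import torch
import torch.nn.functional as F

codebook = torch.randn(16, 4)                                # (vocab_size, Cvae), toy values
f = torch.randn(1, 4, 8, 8)                                  # feature map f_BChw
f_hat, f_rest = torch.zeros_like(f), f.clone()
for pn in (1, 2, 4, 8):                                      # small -> large scales
    z = F.interpolate(f_rest, size=(pn, pn), mode='area') if pn != 8 else f_rest
    z_NC = z.permute(0, 2, 3, 1).reshape(-1, 4)
    idx = torch.cdist(z_NC, codebook).argmin(dim=1)          # nearest codebook entry per position
    h = codebook[idx].view(1, pn, pn, 4).permute(0, 3, 1, 2)
    h = F.interpolate(h, size=(8, 8), mode='bicubic') if pn != 8 else h
    f_hat, f_rest = f_hat + h, f_rest - h                    # accumulate reconstruction, shrink residual
print(f_rest.abs().mean())                                   # residual left after all scales
```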
def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
ls_f_hat_BChw = []
B = ms_h_BChw[0].shape[0]
H = W = self.v_patch_nums[-1]
SN = len(self.v_patch_nums)
if all_to_max_scale:
f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
for si, pn in enumerate(self.v_patch_nums): # from small to large
h_BChw = ms_h_BChw[si]
if si < len(self.v_patch_nums) - 1:
h_BChw = F.interpolate(h_BChw, size=(H, W), mode='bicubic')
h_BChw = self.quant_resi[si/(SN-1)](h_BChw)
f_hat.add_(h_BChw)
if last_one: ls_f_hat_BChw = f_hat
else: ls_f_hat_BChw.append(f_hat.clone())
else:
            # WARNING: this branch is not what VQ-VAE training or inference does (there, every token map is interpolated to the max H, W as above)
            # WARNING: it should only be used for experimental purposes
f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, self.v_patch_nums[0], self.v_patch_nums[0], dtype=torch.float32)
for si, pn in enumerate(self.v_patch_nums): # from small to large
f_hat = F.interpolate(f_hat, size=(pn, pn), mode='bicubic')
h_BChw = self.quant_resi[si/(SN-1)](ms_h_BChw[si])
f_hat.add_(h_BChw)
if last_one: ls_f_hat_BChw = f_hat
else: ls_f_hat_BChw.append(f_hat)
return ls_f_hat_BChw
    def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[Union[torch.Tensor, torch.LongTensor]]:  # f_BChw is the feature map produced from inp_img_no_grad
B, C, H, W = f_BChw.shape
f_no_grad = f_BChw.detach()
f_rest = f_no_grad.clone()
f_hat = torch.zeros_like(f_rest)
f_hat_or_idx_Bl: List[torch.Tensor] = []
patch_hws = [(pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) for pn in (v_patch_nums or self.v_patch_nums)] # from small to large
assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})'
SN = len(patch_hws)
for si, (ph, pw) in enumerate(patch_hws): # from small to large
if 0 <= self.prog_si < si: break # progressive training: not supported yet, prog_si always -1
# find the nearest embedding
z_NC = F.interpolate(f_rest, size=(ph, pw), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
if self.using_znorm:
z_NC = F.normalize(z_NC, dim=-1)
idx_N = torch.argmax(z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
else:
d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)
d_no_grad.addmm_(z_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
idx_N = torch.argmin(d_no_grad, dim=1)
idx_Bhw = idx_N.view(B, ph, pw)
h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
h_BChw = self.quant_resi[si/(SN-1)](h_BChw)
f_hat.add_(h_BChw)
f_rest.sub_(h_BChw)
f_hat_or_idx_Bl.append(f_hat.clone() if to_fhat else idx_N.reshape(B, ph*pw))
return f_hat_or_idx_Bl
# ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
next_scales = []
B = gt_ms_idx_Bl[0].shape[0]
C = self.Cvae
H = W = self.v_patch_nums[-1]
SN = len(self.v_patch_nums)
f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
pn_next: int = self.v_patch_nums[0]
for si in range(SN-1):
if self.prog_si == 0 or (0 <= self.prog_si-1 < si): break # progressive training: not supported yet, prog_si always -1
h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next), size=(H, W), mode='bicubic')
f_hat.add_(self.quant_resi[si/(SN-1)](h_BChw))
pn_next = self.v_patch_nums[si+1]
next_scales.append(F.interpolate(f_hat, size=(pn_next, pn_next), mode='area').view(B, C, -1).transpose(1, 2))
return torch.cat(next_scales, dim=1) if len(next_scales) else None # cat BlCs to BLC, this should be float32
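For the default patch pyramid, the teacher-forcing sequence returned above has `L - first_l` positions, because the first scale's input is the class (sos) token rather than a token map. A quick check of the lengths:
```
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
L = sum(pn * pn for pn in patch_nums)
print(L, L - patch_nums[0] ** 2)                             # 680 679
```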
# ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]: # only used in VAR inference
HW = self.v_patch_nums[-1]
if si != SN-1:
h = self.quant_resi[si/(SN-1)](F.interpolate(h_BChw, size=(HW, HW), mode='bicubic')) # conv after upsample
f_hat.add_(h)
return f_hat, F.interpolate(f_hat, size=(self.v_patch_nums[si+1], self.v_patch_nums[si+1]), mode='area')
else:
h = self.quant_resi[si/(SN-1)](h_BChw)
f_hat.add_(h)
return f_hat, f_hat
class Phi(nn.Conv2d):
def __init__(self, embed_dim, quant_resi):
ks = 3
super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks//2)
self.resi_ratio = abs(quant_resi)
def forward(self, h_BChw):
return h_BChw.mul(1-self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)
class PhiShared(nn.Module):
def __init__(self, qresi: Phi):
super().__init__()
self.qresi: Phi = qresi
def __getitem__(self, _) -> Phi:
return self.qresi
class PhiPartiallyShared(nn.Module):
def __init__(self, qresi_ls: nn.ModuleList):
super().__init__()
self.qresi_ls = qresi_ls
K = len(qresi_ls)
self.ticks = np.linspace(1/3/K, 1-1/3/K, K) if K == 4 else np.linspace(1/2/K, 1-1/2/K, K)
def __getitem__(self, at_from_0_to_1: float) -> Phi:
return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]
def extra_repr(self) -> str:
return f'ticks={self.ticks}'
class PhiNonShared(nn.ModuleList):
def __init__(self, qresi: List):
super().__init__(qresi)
# self.qresi = qresi
K = len(qresi)
self.ticks = np.linspace(1/3/K, 1-1/3/K, K) if K == 4 else np.linspace(1/2/K, 1-1/2/K, K)
def __getitem__(self, at_from_0_to_1: float) -> Phi:
return super().__getitem__(np.argmin(np.abs(self.ticks - at_from_0_to_1)).item())
def extra_repr(self) -> str:
return f'ticks={self.ticks}'
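`quant_resi[si/(SN-1)]` indexes these phi containers with a relative scale position in [0, 1]; the partially-shared and non-shared variants snap that position to the nearest of K evenly spaced ticks. A quick illustration of the mapping for K = 4, using the same tick layout as above:
```
import numpy as np

K = 4
ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K)             # the K == 4 branch of the tick layout above
for pos in (0.0, 1 / 3, 2 / 3, 1.0):                         # si / (SN - 1) for a 4-scale pyramid
    print(round(pos, 3), '->', int(np.argmin(np.abs(ticks - pos))))
```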
import math
from functools import partial
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
import dist
from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
from models.vqvae import VQVAE, VectorQuantizer2
class SharedAdaLin(nn.Linear):
def forward(self, cond_BD):
C = self.weight.shape[0] // 6
return super().forward(cond_BD).view(-1, 1, 6, C) # B16C
class VAR(nn.Module):
def __init__(
self, vae_local: VQVAE,
num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
attn_l2_norm=False,
patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
flash_if_available=True, fused_if_available=True,
):
super().__init__()
# 0. hyperparameters
assert embed_dim % num_heads == 0
self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size
self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads
self.cond_drop_rate = cond_drop_rate
self.prog_si = -1 # progressive training
self.patch_nums: Tuple[int] = patch_nums
self.L = sum(pn ** 2 for pn in self.patch_nums)
self.first_l = self.patch_nums[0] ** 2
self.begin_ends = []
cur = 0
for i, pn in enumerate(self.patch_nums):
self.begin_ends.append((cur, cur+pn ** 2))
cur += pn ** 2
self.num_stages_minus_1 = len(self.patch_nums) - 1
self.rng = torch.Generator(device=dist.get_device())
# 1. input (word) embedding
quant: VectorQuantizer2 = vae_local.quantize
self.vae_proxy: Tuple[VQVAE] = (vae_local,)
self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,)
self.word_embed = nn.Linear(self.Cvae, self.C)
# 2. class embedding
init_std = math.sqrt(1 / self.C / 3)
self.num_classes = num_classes
self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32, device=dist.get_device())
self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
# 3. absolute position embedding
pos_1LC = []
for i, pn in enumerate(self.patch_nums):
pe = torch.empty(1, pn*pn, self.C)
nn.init.trunc_normal_(pe, mean=0, std=init_std)
pos_1LC.append(pe)
pos_1LC = torch.cat(pos_1LC, dim=1) # 1, L, C
assert tuple(pos_1LC.shape) == (1, self.L, self.C)
self.pos_1LC = nn.Parameter(pos_1LC)
        # level embedding (similar to GPT's segment embedding, used to distinguish different levels of the token pyramid)
self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
# 4. backbone blocks
self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
self.drop_path_rate = drop_path_rate
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule (linearly increasing)
self.blocks = nn.ModuleList([
AdaLNSelfAttn(
cond_dim=self.D, shared_aln=shared_aln,
block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx], last_drop_p=0 if block_idx == 0 else dpr[block_idx-1],
attn_l2_norm=attn_l2_norm,
flash_if_available=flash_if_available, fused_if_available=fused_if_available,
)
for block_idx in range(depth)
])
fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
self.using_fused_add_norm_fn = any(fused_add_norm_fns)
print(
f'\n[constructor] ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n'
f' [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n'
f' [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
end='\n\n', flush=True
)
# 5. attention mask used in training (for masking out the future)
# it won't be used in inference, since kv cache is enabled
d: torch.Tensor = torch.cat([torch.full((pn*pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L, 1)
dT = d.transpose(1, 2) # dT: 11L
lvl_1L = dT[:, 0].contiguous()
self.register_buffer('lvl_1L', lvl_1L)
attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L)
self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())
# 6. classifier head
self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
self.head = nn.Linear(self.C, self.V)
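The buffer built in step 5 above is a block-causal mask: a query token may attend to every token at its own or any coarser scale, and is blocked from finer (future) scales. A toy pyramid makes the pattern explicit; this is an illustration only, not part of the original file.
```
import torch

patch_nums = (1, 2)                                          # toy pyramid: 1 + 4 = 5 tokens
L = sum(pn * pn for pn in patch_nums)
d = torch.cat([torch.full((pn * pn,), i) for i, pn in enumerate(patch_nums)]).view(1, L, 1)
dT = d.transpose(1, 2)
mask = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, L, L)
print(mask[0, 0, 0])                                         # scale-0 token: attends only to itself
print(mask[0, 0, 1])                                         # scale-1 tokens: attend to all tokens at scales 0 and 1
```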
def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], cond_BD: Optional[torch.Tensor]):
if not isinstance(h_or_h_and_residual, torch.Tensor):
h, resi = h_or_h_and_residual # fused_add_norm must be used
h = resi + self.blocks[-1].drop_path(h)
else: # fused_add_norm is not used
h = h_or_h_and_residual
return self.head(self.head_nm(h.float(), cond_BD).float()).float()
@torch.no_grad()
def autoregressive_infer_cfg(
self, B: int, label_B: Optional[Union[int, torch.LongTensor]],
g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0,
more_smooth=False,
) -> torch.Tensor: # returns reconstructed image (B, 3, H, W) in [0, 1]
"""
        only used for inference, in autoregressive (scale-by-scale) sampling mode
:param B: batch size
:param label_B: imagenet label; if None, randomly sampled
:param g_seed: random seed
:param cfg: classifier-free guidance ratio
:param top_k: top-k sampling
:param top_p: top-p sampling
:param more_smooth: smoothing the pred using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
        :return: reconstructed images of shape (B, 3, H, W), with values in [0, 1]
"""
if g_seed is None: rng = None
else: self.rng.manual_seed(g_seed); rng = self.rng
if label_B is None:
label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B)
elif isinstance(label_B, int):
label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=self.lvl_1L.device)
sos = cond_BD = self.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=self.num_classes)), dim=0))
lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC
next_token_map = sos.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1) + lvl_pos[:, :self.first_l]
cur_L = 0
f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])
for b in self.blocks: b.attn.kv_caching(True)
for si, pn in enumerate(self.patch_nums): # si: i-th segment
ratio = si / self.num_stages_minus_1
# last_L = cur_L
cur_L += pn*pn
# assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'
cond_BD_or_gss = self.shared_ada_lin(cond_BD)
x = next_token_map
            AdaLNSelfAttn.forward  # no-op attribute access; has no runtime effect (kept as a quick jump-to-definition hint)
for b in self.blocks:
x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)
logits_BlV = self.get_logits(x, cond_BD)
t = cfg * ratio
logits_BlV = (1+t) * logits_BlV[:B] - t * logits_BlV[B:]
idx_Bl = sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0]
if not more_smooth: # this is the default case
h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl) # B, l, Cvae
else: # not used when evaluating FID/IS/Precision/Recall
gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git
h_BChw = gumbel_softmax_with_rng(logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng) @ self.vae_quant_proxy[0].embedding.weight.unsqueeze(0)
h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn)
f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(si, len(self.patch_nums), f_hat, h_BChw)
if si != self.num_stages_minus_1: # prepare for next stage
next_token_map = next_token_map.view(B, self.Cvae, -1).transpose(1, 2)
next_token_map = self.word_embed(next_token_map) + lvl_pos[:, cur_L:cur_L + self.patch_nums[si+1] ** 2]
next_token_map = next_token_map.repeat(2, 1, 1) # double the batch sizes due to CFG
for b in self.blocks: b.attn.kv_caching(False)
return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5) # de-normalize, from [-1, 1] to [0, 1]
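Classifier-free guidance above runs the conditional and unconditional halves of the doubled batch together and mixes their logits with a weight `t = cfg * ratio` that grows linearly with the scale index, so coarse scales are guided weakly and fine scales strongly. A numeric sketch of the mixing rule with toy logits:
```
import torch

cond, uncond = torch.tensor([2.0, 0.5]), torch.tensor([1.0, 1.0])    # toy logits for one position
cfg = 1.5
for ratio in (0.0, 0.5, 1.0):                                 # si / num_stages_minus_1
    t = cfg * ratio
    print(ratio, ((1 + t) * cond - t * uncond).tolist())
```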
def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor) -> torch.Tensor: # returns logits_BLV
"""
:param label_B: label_B
:param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
:return: logits BLV, V is vocab_size
"""
bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)
B = x_BLCv_wo_first_l.shape[0]
with torch.cuda.amp.autocast(enabled=False):
label_B = torch.where(torch.rand(B, device=label_B.device) < self.cond_drop_rate, self.num_classes, label_B)
sos = cond_BD = self.class_emb(label_B)
sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)
if self.prog_si == 0: x_BLC = sos
else: x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1)
x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed] # lvl: BLC; pos: 1LC
attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
cond_BD_or_gss = self.shared_ada_lin(cond_BD)
# hack: get the dtype if mixed precision is used
temp = x_BLC.new_ones(8, 8)
main_type = torch.matmul(temp, temp).dtype
x_BLC = x_BLC.to(dtype=main_type)
cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
attn_bias = attn_bias.to(dtype=main_type)
        AdaLNSelfAttn.forward  # no-op attribute access; has no runtime effect (kept as a quick jump-to-definition hint)
for i, b in enumerate(self.blocks):
x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)
x_BLC = self.get_logits(x_BLC.float(), cond_BD)
        if self.prog_si == 0:  # progressive-training path: the "* 0" terms below keep word_embed parameters in the autograd graph even though only the sos token is used
if isinstance(self.word_embed, nn.Linear):
x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0
else:
s = 0
for p in self.word_embed.parameters():
if p.requires_grad:
s += p.view(-1)[0] * 0
x_BLC[0, 0, 0] += s
return x_BLC # logits BLV, V is vocab_size
def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02):
if init_std < 0: init_std = (1 / self.C / 3) ** 0.5 # init_std < 0: automated
print(f'[init_weights] {type(self).__name__} with {init_std=:g}')
for m in self.modules():
with_weight = hasattr(m, 'weight') and m.weight is not None
with_bias = hasattr(m, 'bias') and m.bias is not None
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight.data, std=init_std)
if with_bias: m.bias.data.zero_()
elif isinstance(m, nn.Embedding):
nn.init.trunc_normal_(m.weight.data, std=init_std)
if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
if with_weight: m.weight.data.fill_(1.)
if with_bias: m.bias.data.zero_()
# conv: VAR has no conv, only VQVAE has conv
elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
if conv_std_or_gain > 0: nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
else: nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
if with_bias: m.bias.data.zero_()
if init_head >= 0:
if isinstance(self.head, nn.Linear):
self.head.weight.data.mul_(init_head)
self.head.bias.data.zero_()
elif isinstance(self.head, nn.Sequential):
self.head[-1].weight.data.mul_(init_head)
self.head[-1].bias.data.zero_()
if isinstance(self.head_nm, AdaLNBeforeHead):
self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
self.head_nm.ada_lin[-1].bias.data.zero_()
depth = len(self.blocks)
for block_idx, sab in enumerate(self.blocks):
sab: AdaLNSelfAttn
sab.attn.proj.weight.data.div_(math.sqrt(2 * depth))
sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None:
nn.init.ones_(sab.ffn.fcg.bias)
nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
if hasattr(sab, 'ada_lin'):
sab.ada_lin[-1].weight.data[2*self.C:].mul_(init_adaln)
sab.ada_lin[-1].weight.data[:2*self.C].mul_(init_adaln_gamma)
if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None:
sab.ada_lin[-1].bias.data.zero_()
elif hasattr(sab, 'ada_gss'):
sab.ada_gss.data[:, :, 2:].mul_(init_adaln)
sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)
def extra_repr(self):
return f'drop_path_rate={self.drop_path_rate:g}'
class VARHF(VAR, PyTorchModelHubMixin):
# repo_url="https://github.com/FoundationVision/VAR",
# tags=["image-generation"]):
def __init__(
self,
vae_kwargs,
num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
attn_l2_norm=False,
patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
flash_if_available=True, fused_if_available=True,
):
vae_local = VQVAE(**vae_kwargs)
super().__init__(
vae_local=vae_local,
num_classes=num_classes, depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
norm_eps=norm_eps, shared_aln=shared_aln, cond_drop_rate=cond_drop_rate,
attn_l2_norm=attn_l2_norm,
patch_nums=patch_nums,
flash_if_available=flash_if_available, fused_if_available=fused_if_available,
)
"""
References:
- VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110
- GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213
- VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14
"""
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from .basic_vae import Decoder, Encoder
from .quant import VectorQuantizer2
class VQVAE(nn.Module):
def __init__(
self, vocab_size=4096, z_channels=32, ch=128, dropout=0.0,
beta=0.25, # commitment loss weight
using_znorm=False, # whether to normalize when computing the nearest neighbors
quant_conv_ks=3, # quant conv kernel size
quant_resi=0.5, # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x
share_quant_resi=4, # use 4 \phi layers for K scales: partially-shared \phi
default_qresi_counts=0, # if is 0: automatically set to len(v_patch_nums)
v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k]
test_mode=True,
):
super().__init__()
self.test_mode = test_mode
self.V, self.Cvae = vocab_size, z_channels
# ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml
ddconfig = dict(
dropout=dropout, ch=ch, z_channels=z_channels,
in_channels=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2, # from vq-f16/config.yaml above
using_sa=True, using_mid_sa=True, # from vq-f16/config.yaml above
# resamp_with_conv=True, # always True, removed.
)
ddconfig.pop('double_z', None) # only KL-VAE should use double_z=True
self.encoder = Encoder(double_z=False, **ddconfig)
self.decoder = Decoder(**ddconfig)
self.vocab_size = vocab_size
self.downsample = 2 ** (len(ddconfig['ch_mult'])-1)
self.quantize: VectorQuantizer2 = VectorQuantizer2(
vocab_size=vocab_size, Cvae=self.Cvae, using_znorm=using_znorm, beta=beta,
default_qresi_counts=default_qresi_counts, v_patch_nums=v_patch_nums, quant_resi=quant_resi, share_quant_resi=share_quant_resi,
)
self.quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)
self.post_quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)
if self.test_mode:
self.eval()
[p.requires_grad_(False) for p in self.parameters()]
# ===================== `forward` is only used in VAE training =====================
def forward(self, inp, ret_usages=False): # -> rec_B3HW, idx_N, loss
        VectorQuantizer2.forward  # no-op attribute access; has no runtime effect (kept as a quick jump-to-definition hint)
f_hat, usages, vq_loss = self.quantize(self.quant_conv(self.encoder(inp)), ret_usages=ret_usages)
return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss
# ===================== `forward` is only used in VAE training =====================
def fhat_to_img(self, f_hat: torch.Tensor):
return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[torch.LongTensor]: # return List[Bl]
f = self.quant_conv(self.encoder(inp_img_no_grad))
return self.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=v_patch_nums)
def idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
B = ms_idx_Bl[0].shape[0]
ms_h_BChw = []
for idx_Bl in ms_idx_Bl:
l = idx_Bl.shape[1]
pn = round(l ** 0.5)
ms_h_BChw.append(self.quantize.embedding(idx_Bl).transpose(1, 2).view(B, self.Cvae, pn, pn))
return self.embed_to_img(ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one)
def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
if last_one:
return self.decoder(self.post_quant_conv(self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True))).clamp_(-1, 1)
else:
return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False)]
def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, last_one=False) -> List[torch.Tensor]:
f = self.quant_conv(self.encoder(x))
ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(f, to_fhat=True, v_patch_nums=v_patch_nums)
if last_one:
return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)
else:
return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in ls_f_hat_BChw]
def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):
if 'quantize.ema_vocab_hit_SV' in state_dict and state_dict['quantize.ema_vocab_hit_SV'].shape[0] != self.quantize.ema_vocab_hit_SV.shape[0]:
state_dict['quantize.ema_vocab_hit_SV'] = self.quantize.ema_vocab_hit_SV
return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
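A usage sketch for the tokenizer interface above. This is illustrative only: the model here is randomly initialized, so load a pretrained VQVAE checkpoint before expecting meaningful reconstructions, and note that inputs are assumed to be normalized to [-1, 1].
```
import torch
# assumes this file is importable as models.vqvae in the repo layout
# from models.vqvae import VQVAE

vae = VQVAE(test_mode=True)                                   # eval mode, gradients disabled
img = torch.rand(1, 3, 256, 256) * 2 - 1                      # 256 = 16 (downsample) * 16 (last patch num)
idx_Bl = vae.img_to_idxBl(img)                                # one (B, pn*pn) LongTensor per scale
print([tuple(t.shape) for t in idx_Bl])                       # [(1, 1), (1, 4), ..., (1, 256)]
rec = vae.idxBl_to_img(idx_Bl, same_shape=True, last_one=True)
print(rec.shape)                                              # torch.Size([1, 3, 256, 256]), values in [-1, 1]
```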