# PixArt-alpha ## 论文 `PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis` * https://arxiv.org/abs/2310.00426 ## 模型结构 该模型基于`DiT(Diffusion Transformer)`模型,添加了`Multi-Head Cross-Attention`用于对其文本与图像。 ![alt text](readme_imgs/image-1.png) ## 算法原理 模型中主要涉及`Multi-Head Self-Attention`和`Multi-Head Cross-Attention`,其中`Multi-Head Self-Attention`主要用于对图像建模,`Multi-Head Cross-Attention`用于对齐图像与文本。 ![alt text](readme_imgs/image-2.png) ## 环境配置 ### Docker(方法一) docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 docker run --shm-size 10g --network=host --name=pixart-alpha --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it bash pip install -r requirements.txt pip install timm --no-deps pip uninstall apex # 安装diffusers # 手动安装 git clone https://github.com/huggingface/diffusers.git cd diffusers && python setup.py install # 自动安装 pip install git+https://github.com/huggingface/diffusers ### Dockerfile(方法二) # 需要在对应的目录下 docker build -t : . docker run --shm-size 10g --network=host --name=pixart-alpha --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it bash pip install -r requirements.txt pip install timm --no-deps pip uninstall apex # 安装diffusers # 手动安装 git clone https://github.com/huggingface/diffusers.git cd diffusers && python setup.py install # 自动安装 pip install git+https://github.com/huggingface/diffusers ### Anaconda (方法三) 1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.sourcefind.cn/tool/ DTK驱动:dtk24.04.1 python:python3.10 torch:2.1.0 torchvision:0.16.0 triton:2.1.0 Tips:以上dtk驱动、python、torch等DCU相关工具版本需要严格一一对应 2、其它非特殊库参照requirements.txt安装 pip install -r requirements.txt pip install timm --no-deps # 安装diffusers # 手动安装 git clone https://github.com/huggingface/diffusers.git cd diffusers && python setup.py install # 自动安装 pip install git+https://github.com/huggingface/diffusers ## 数据集 注意:该数据集为训练数据集 完整数据:https://ai.meta.com/datasets/segment-anything/ 测试数据:https://huggingface.co/datasets/PixArt-alpha/data_toy 数据下载完成后需要进行处理,可运行以下脚本: # 使用LLava获取更加详细的图像描述 python tools/VLM_caption_lightning.py --output output/dir/ --data-root data/root/path --index path/to/data.json # 提前生成训练需要的特征 python tools/extract_features.py --img_size=256 \ --json_path "data/data_toy/data_info.json" \ --t5_save_root "data/data_toy/caption_feature_wmask" \ --vae_save_root "data/data_toy/img_vae_features" \ --pretrained_models_dir "pretrained_models/hub/pixart_alpha" \ --dataset_root "data/data_toy/images/" 处理后获得下述数据结构 data/ └── data_toy ├── caption_feature_wmask │   ├── 0_1.npz │   └── 0_3.npz ├── captions │   ├── 0_1.txt │   └── 0_3.txt ├── data_info.json ├── images │   ├── 0_1.png │   └── 0_3.png ├── img_vae_features │   └── 256resolution │   └── noflip │   ├── 0_1.npy │   └── 0_3.npy └── partition └── part0.txt ## 训练 无 ## 推理 ### 模型下载 |Model+url|存放位置| |:---:|:---:| |[T5](https://hf-mirror.com/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl)|/path/to/save/models/pixart_alpha/t5_ckpts| |[sd-vae-ft-ema](https://hf-mirror.com/PixArt-alpha/PixArt-alpha/tree/main/sd-vae-ft-ema)|/path/to/save/models/pixart_alpha/sd-vae-ft-ema| pixart_alpha/ ├── sd-vae-ft-ema │ ├── config.json │ └── diffusion_pytorch_model.bin └── t5_ckpts └── t5-v1_1-xxl ├── config.json ├── pytorch_model-00001-of-00002.bin ├── pytorch_model-00002-of-00002.bin ├── pytorch_model.bin.index.json ├── special_tokens_map.json ├── spiece.model └── tokenizer_config.json 注意:上述模型需手动下载,其余模型将在运行时自动下载。 export HF_ENDPOINT=https://hf-mirror.com export HUB_HOME=/path/to/save/models ### 命令 # 快速测试 HIP_VISIBLE_DEVICES=0 python quick_inference_with_code.py ### WebUI # diffusers version DEMO_PORT=12345 python app/app.py ## result |prompt|output| |:---:|:---:| |a dog is playing a basketball|![alt text](readme_imgs/image-3.png)| ### 精度 无 ## 应用场景 ### 算法类别 `AIGC` ### 热点应用行业 `零售,广媒,教育` ## 源码仓库及问题反馈 * https://developer.sourcefind.cn/codes/modelzoo/pixart-alpha_pytorch ## 参考资料 * https://github.com/PixArt-alpha/PixArt-alpha