Commit 1bfbcff0 authored by wanglch's avatar wanglch
Browse files

Initial commit

parents
Pipeline #1204 canceled with stages
# Train a ControlNet on Stable Diffusion XL using the fill50k conditioning dataset.
# Fix: removed the line-continuation backslash after the final argument — a
# trailing "\" at the end of a script splices whatever follows into the command.
PYTHONPATH=../../.. \
accelerate launch train_controlnet_sdxl.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \
--output_dir="train_controlnet_sdxl" \
--dataset_name="AI-ModelScope/controlnet_dataset_condition_fill50k" \
--mixed_precision="fp16" \
--resolution=1024 \
--learning_rate=1e-5 \
--max_train_steps=15000 \
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
--validation_steps=100 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--report_to="tensorboard" \
--seed=42
# DreamBooth fine-tuning of Stable Diffusion v1.5 on a local instance-image folder.
# Fix: removed the dangling line-continuation backslash after the final argument.
PYTHONPATH=../../.. \
accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \
--instance_data_dir="./dog-example" \
--output_dir="train_dreambooth" \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=400
# DreamBooth + LoRA fine-tuning of Stable Diffusion v1.5.
# Fix: removed the dangling line-continuation backslash after the final argument.
PYTHONPATH=../../.. \
accelerate launch train_dreambooth_lora.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \
--instance_data_dir="./dog-example" \
--output_dir="train_dreambooth_lora" \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--checkpointing_steps=100 \
--learning_rate=1e-4 \
--report_to="tensorboard" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=50 \
--seed="0"
# DreamBooth + LoRA fine-tuning of Stable Diffusion XL (fp16, fixed fp16 VAE).
# Fix: removed the dangling line-continuation backslash after the final argument.
PYTHONPATH=../../ \
accelerate launch train_dreambooth_lora_sdxl.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \
--instance_data_dir="./dog-example" \
--pretrained_vae_model_name_or_path="AI-ModelScope/sdxl-vae-fp16-fix" \
--output_dir="sdxl-dog-dreambooth-lora" \
--mixed_precision="fp16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-5 \
--report_to="tensorboard" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0"
# Full text-to-image fine-tuning of Stable Diffusion v1.5 on the pokemon captions dataset.
# Fix: removed the dangling line-continuation backslash after the final argument.
PYTHONPATH=../../../ \
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \
--dataset_name="AI-ModelScope/pokemon-blip-captions" \
--use_ema \
--resolution=512 \
--center_crop \
--random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--checkpointing_steps=500 \
--checkpoints_total_limit 2 \
--output_dir="train_text_to_image"
# LoRA text-to-image fine-tuning of Stable Diffusion v1.5.
# Fix: removed the dangling line-continuation backslash after the final argument.
PYTHONPATH=../../../ \
accelerate launch train_text_to_image_lora.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \
--dataset_name="AI-ModelScope/pokemon-blip-captions" \
--caption_column="text" \
--resolution=512 \
--random_flip \
--train_batch_size=1 \
--num_train_epochs=100 \
--checkpointing_steps=5000 \
--learning_rate=1e-04 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--mixed_precision="fp16" \
--seed=42 \
--output_dir="train_text_to_image_lora" \
--validation_prompt="cute dragon creature" \
--report_to="tensorboard"
# LoRA text-to-image fine-tuning of Stable Diffusion XL (fp16, fixed fp16 VAE).
# Fix: removed the dangling line-continuation backslash after the final argument.
PYTHONPATH=../../../ \
accelerate launch train_text_to_image_lora_sdxl.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \
--pretrained_vae_model_name_or_path="AI-ModelScope/sdxl-vae-fp16-fix" \
--dataset_name="AI-ModelScope/pokemon-blip-captions" \
--caption_column="text" \
--resolution=1024 \
--random_flip \
--train_batch_size=1 \
--num_train_epochs=2 \
--checkpointing_steps=500 \
--learning_rate=1e-04 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--mixed_precision="fp16" \
--seed=42 \
--output_dir="train_text_to_image_lora_sdxl" \
--validation_prompt="cute dragon creature" \
--report_to="tensorboard"
# Full text-to-image fine-tuning of Stable Diffusion XL (fp16, 8-bit Adam).
# Fix: removed the dangling line-continuation backslash after the final argument.
PYTHONPATH=../../../ \
accelerate launch train_text_to_image_sdxl.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \
--pretrained_vae_model_name_or_path="AI-ModelScope/sdxl-vae-fp16-fix" \
--dataset_name="AI-ModelScope/pokemon-blip-captions" \
--resolution=512 \
--center_crop \
--random_flip \
--proportion_empty_prompts=0.2 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=10000 \
--use_8bit_adam \
--learning_rate=1e-06 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--mixed_precision="fp16" \
--report_to="tensorboard" \
--validation_prompt="a cute Sundar Pichai creature" \
--validation_epochs 5 \
--checkpointing_steps=5000 \
--output_dir="train_text_to_image_sdxl"
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Command-line entry point delegating to swift's ControlNet trainer."""
from swift.aigc import train_controlnet


def main() -> None:
    """Run ControlNet training with arguments taken from the command line."""
    train_controlnet()


if __name__ == '__main__':
    main()
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Command-line entry point delegating to swift's SDXL ControlNet trainer."""
from swift.aigc import train_controlnet_sdxl


def main() -> None:
    """Run SDXL ControlNet training with CLI-supplied arguments."""
    train_controlnet_sdxl()


if __name__ == '__main__':
    main()
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Command-line entry point delegating to swift's DreamBooth trainer."""
from swift.aigc import train_dreambooth


def main() -> None:
    """Run DreamBooth training with CLI-supplied arguments."""
    train_dreambooth()


if __name__ == '__main__':
    main()
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Command-line entry point delegating to swift's DreamBooth LoRA trainer."""
from swift.aigc import train_dreambooth_lora


def main() -> None:
    """Run DreamBooth LoRA training with CLI-supplied arguments."""
    train_dreambooth_lora()


if __name__ == '__main__':
    main()
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Command-line entry point delegating to swift's SDXL DreamBooth LoRA trainer."""
from swift.aigc import train_dreambooth_lora_sdxl


def main() -> None:
    """Run SDXL DreamBooth LoRA training with CLI-supplied arguments."""
    train_dreambooth_lora_sdxl()


if __name__ == '__main__':
    main()
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Command-line entry point delegating to swift's text-to-image trainer."""
from swift.aigc import train_text_to_image


def main() -> None:
    """Run text-to-image training with CLI-supplied arguments."""
    train_text_to_image()


if __name__ == '__main__':
    main()
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Command-line entry point delegating to swift's text-to-image LoRA trainer."""
from swift.aigc import train_text_to_image_lora


def main() -> None:
    """Run text-to-image LoRA training with CLI-supplied arguments."""
    train_text_to_image_lora()


if __name__ == '__main__':
    main()
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Command-line entry point delegating to swift's SDXL text-to-image LoRA trainer."""
from swift.aigc import train_text_to_image_lora_sdxl


def main() -> None:
    """Run SDXL text-to-image LoRA training with CLI-supplied arguments."""
    train_text_to_image_lora_sdxl()


if __name__ == '__main__':
    main()
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Command-line entry point delegating to swift's SDXL text-to-image trainer."""
from swift.aigc import train_text_to_image_sdxl


def main() -> None:
    """Run SDXL text-to-image training with CLI-supplied arguments."""
    train_text_to_image_sdxl()


if __name__ == '__main__':
    main()
<h1 align="center">Stable Diffusion Example</h1>
<p align="center">
<a href="https://modelscope.cn/home">Modelscope Hub</a>
<br>
<a href="README_CN.md">中文</a>&nbsp; | &nbsp;English
</p>
## Features
1. Support Stable Diffusion [LoRA](https://arxiv.org/abs/2106.09685) method.
2. Support Stable Diffusion XL [LoRA](https://arxiv.org/abs/2106.09685) method.
## Prepare the Environment
```bash
pip install -r requirements.txt
```
## Train and Inference
```bash
# Clone the repository and enter the code directory.
git clone https://github.com/modelscope/swift.git
cd swift
# Stable Diffusion LoRA
bash examples/pytorch/stable_diffusion/run_train_lora.sh
# Stable Diffusion XL LoRA
bash examples/pytorch/stable_diffusion/run_train_lora_xl.sh
```
## Extend Datasets
The [buptwq/lora-stable-diffusion-finetune](https://www.modelscope.cn/datasets/buptwq/lora-stable-diffusion-finetune/summary) dataset used in the example is from the [ModelScope Hub](https://www.modelscope.cn/my/overview); you can switch to a different dataset by setting its ID in the `train_dataset_name` parameter.
In addition, you can also use local datasets. Fill in the path of the dataset directory in the `train_dataset_name` parameter; the directory must contain a `train.csv` file that maps image files to text prompts. Please organize it into the following format:
```
Text,Target:FILE
[prompt], [image dir]
......
```
Here is an example of `train.csv` file:
```
Text,Target:FILE
a dog,target/00.jpg
a dog,target/01.jpg
a dog,target/02.jpg
a dog,target/03.jpg
a dog,target/04.jpg
```
<h1 align="center">微调稳定扩散模型例子</h1>
<p align="center">
<a href="https://modelscope.cn/home">魔搭社区</a>
<br>
中文&nbsp; | &nbsp;<a href="README.md">English</a>
</p>
## 特性
1. 支持[LoRA](https://arxiv.org/abs/2106.09685)方法微调稳定扩散模型。
2. 支持[LoRA](https://arxiv.org/abs/2106.09685)方法微调XL版本的稳定扩散模型。
## 环境准备
```bash
pip install -r requirements.txt
```
## 训练和推理
```bash
# 克隆代码库并进入代码目录
git clone https://github.com/modelscope/swift.git
cd swift
# LoRA方法微调和推理稳定扩散模型
bash examples/pytorch/stable_diffusion/run_train_lora.sh
# LoRA方法微调和推理XL版本的稳定扩散模型
bash examples/pytorch/stable_diffusion/run_train_lora_xl.sh
```
## 数据集拓展
示例中使用的数据集[buptwq/lora-stable-diffusion-finetune](https://www.modelscope.cn/datasets/buptwq/lora-stable-diffusion-finetune/summary)来自[ModelScope Hub](https://www.modelscope.cn/my/overview),您可以在ModelScope Hub选择其他数据集,用被选择的数据集ID来修改`train_dataset_name`参数。
除此之外,您也可以使用本地数据集。请用本地数据集路径修改`train_dataset_name`参数,请注意在本地数据集路径中应该包含一个`train.csv`文件用来映射图片和文本提示词。`train.csv`文件请参照以下的格式:
```
Text,Target:FILE
[提示词], [图片路径]
......
```
下面是一个 `train.csv` 文件的例子:
```
Text,Target:FILE
a dog,target/00.jpg
a dog,target/01.jpg
a dog,target/02.jpg
a dog,target/03.jpg
a dog,target/04.jpg
```
import os
from dataclasses import dataclass, field
import cv2
import torch
from modelscope import get_logger, snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.trainers.training_args import TrainingArgs
from modelscope.utils.constant import DownloadMode, Tasks
from swift import LoRAConfig, Swift
logger = get_logger()
# Command-line argument definitions for LoRA fine-tuning and post-training sampling
@dataclass(init=False)
class StableDiffusionLoraArguments(TrainingArgs):
    """CLI arguments for LoRA fine-tuning of a text-to-image diffusion model.

    Extends the ModelScope ``TrainingArgs`` with LoRA hyper-parameters plus
    the sampling options used for inference after training.

    Fixes: corrected typos in user-facing help strings
    ("Values ca be" -> "Values can be", "The numbers of" -> "The number of").
    """

    # Text prompt used for post-training sample generation.
    prompt: str = field(
        default='dog', metadata={
            'help': 'The pipeline prompt.',
        })
    # Rank of the LoRA low-rank decomposition.
    lora_rank: int = field(
        default=4, metadata={
            'help': 'The rank size of lora intermediate linear.',
        })
    # Scaling factor applied to the LoRA weights.
    lora_alpha: int = field(
        default=32, metadata={
            'help': 'The factor to add the lora weights',
        })
    lora_dropout: float = field(
        default=0.0, metadata={
            'help': 'The dropout rate of the lora module',
        })
    bias: str = field(
        default='none', metadata={
            'help': 'Bias type. Values can be "none", "all" or "lora_only"',
        })
    # How many images to sample with the trained pipeline.
    sample_nums: int = field(
        default=10, metadata={
            'help': 'The number of sample outputs',
        })
    num_inference_steps: int = field(
        default=50, metadata={
            'help': 'The number of denoising steps.',
        })
# Parse command-line options for the text-to-image LoRA task.
training_args = StableDiffusionLoraArguments(task='text-to-image-synthesis').parse_cli()
config, args = training_args.to_config()

dataset_id = args.train_dataset_name
if os.path.exists(dataset_id):
    # Local path: the same on-disk dataset serves both train and validation.
    train_dataset = MsDataset.load(dataset_id)
    validation_dataset = MsDataset.load(dataset_id)
else:
    # ModelScope Hub id: force a fresh download of each split.
    train_dataset = MsDataset.load(dataset_id, split='train', download_mode=DownloadMode.FORCE_REDOWNLOAD)
    validation_dataset = MsDataset.load(
        dataset_id, split='validation', download_mode=DownloadMode.FORCE_REDOWNLOAD)
def cfg_modify_fn(cfg):
    """Fold CLI config into the trainer cfg and pin a constant LR schedule.

    When --use_model_config is set the CLI values are merged into the
    model's own configuration; otherwise the CLI config replaces it.
    """
    if not args.use_model_config:
        cfg = config
    else:
        cfg.merge_from_dict(config)
    # LambdaLR with a constant multiplier of 1 == fixed learning rate.
    cfg.train.lr_scheduler = {
        'type': 'LambdaLR',
        'lr_lambda': lambda _: 1,
        'last_epoch': -1,
    }
    return cfg
# Build the base diffusion model and locate its downloaded snapshot directory
# (the snapshot holds configuration.json used by the trainer below).
model = Model.from_pretrained(training_args.model, revision=args.model_revision)
model_dir = snapshot_download(args.model)
lora_config = LoRAConfig(
    r=args.lora_rank,
    lora_alpha=args.lora_alpha,
    lora_dropout=args.lora_dropout,
    bias=args.bias,
    # Attention projection layers to adapt; both diffusers-style (to_q/to_k/to_v/
    # to_out.0) and query/key/value naming variants are covered.
    target_modules=['to_q', 'to_k', 'to_v', 'query', 'key', 'value', 'to_out.0'])
# Inject LoRA adapters into the UNet only.
# NOTE(review): Swift.prepare_model presumably freezes the base weights — confirm.
model.unet = Swift.prepare_model(model.unet, lora_config)
# Build the trainer and run LoRA fine-tuning.
kwargs = dict(
    model=model,
    # Trainer configuration shipped alongside the downloaded base model.
    cfg_file=os.path.join(model_dir, 'configuration.json'),
    work_dir=training_args.work_dir,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    # Full-precision training; mixed precision is not used here.
    torch_type=torch.float32,
    use_swift=True,
    cfg_modify_fn=cfg_modify_fn)
trainer = build_trainer(name=Trainers.stable_diffusion, default_args=kwargs)
trainer.train()
# Persist only the LoRA-wrapped UNet; the rest of the pipeline is reloaded
# from the base model at inference time.
model.unet.save_pretrained(os.path.join(training_args.work_dir, 'unet'))
logger.info(f'model save pretrained {training_args.work_dir}')
# After training: rebuild the inference pipeline with the saved LoRA weights
# and write a handful of sample generations to disk.
lora_weights_dir = os.path.join(training_args.work_dir, 'unet')
pipe = pipeline(
    task=Tasks.text_to_image_synthesis,
    model=training_args.model,
    model_revision=args.model_revision,
    lora_dir=lora_weights_dir,
    use_swift=True)
for sample_idx in range(args.sample_nums):
    result = pipe({
        'text': args.prompt,
        'num_inference_steps': args.num_inference_steps,
    })
    cv2.imwrite(f'./lora_result_{sample_idx}.png', result['output_imgs'][0])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment