#!/bin/bash

# This script is used for Causal VAE training.
# Training proceeds in two stages:
# Stage-1: mixed image and video training
# Stage-2: pure video training, using context parallelism to load videos with more frames (up to 257 frames)

# GPUS=8  # The number of GPUs
# VAE_MODEL_PATH=PATH/vae_ckpt   # The VAE model dir
# LPIPS_CKPT=vgg_lpips.pth    # The LPIPS VGG checkpoint path, used for computing the LPIPS loss
# OUTPUT_DIR=/PATH/output_dir    # The checkpoint saving dir
# IMAGE_ANNO=annotation/image_text.jsonl   # The image annotation file path
# VIDEO_ANNO=annotation/video_text.jsonl   # The video annotation file path
# RESOLUTION=256     # The training resolution, default is 256
# NUM_FRAMES=17     # x * 8 + 1, the number of video frames
# BATCH_SIZE=2

# Restrict training to these ROCm GPUs (HIP analog of CUDA_VISIBLE_DEVICES)
export HIP_VISIBLE_DEVICES=4,5,6,7
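# On NVIDIA hardware the equivalent device selection would be, e.g.:
# export CUDA_VISIBLE_DEVICES=4,5,6,7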

GPUS=4  # The number of GPUs
VAE_MODEL_PATH=/home/modelzoo/Pyramid-Flow/pyramid_flow_model/pyramid-flow-miniflux/causal_video_vae   # The VAE model dir
LPIPS_CKPT=/home/modelzoo/Pyramid-Flow/pyramid_flow_model/pyramid-flow-miniflux/vgg_lpips.pth    # The LPIPS VGG checkpoint path, used for computing the LPIPS loss
OUTPUT_DIR=./temp_vae    # The checkpoint saving dir
IMAGE_ANNO=annotation/customs/vae_image.jsonl   # The image annotation file path
VIDEO_ANNO=annotation/customs/vae_video.jsonl   # The video annotation file path
RESOLUTION=256     # The training resolution, default is 256
NUM_FRAMES=9     # x * 8 + 1, the number of video frames
BATCH_SIZE=2 
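
# Optional sanity checks (a minimal sketch, not part of the training code):
# NUM_FRAMES must follow the "x * 8 + 1" pattern noted above, and both
# annotation files must exist before launching.
if (( (NUM_FRAMES - 1) % 8 != 0 )); then
    echo "NUM_FRAMES ($NUM_FRAMES) should be of the form x * 8 + 1, e.g. 9 or 17" >&2
    exit 1
fi
for anno in "$IMAGE_ANNO" "$VIDEO_ANNO"; do
    [ -f "$anno" ] || { echo "Annotation file not found: $anno" >&2; exit 1; }
done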

# When using --add_discriminator, disc_start must be set to 0; otherwise training will raise an error.
# Stage-1
torchrun --nproc_per_node $GPUS \
    train/train_video_vae.py \
    --num_workers 6 \
    --model_path $VAE_MODEL_PATH \
    --model_dtype bf16 \
    --lpips_ckpt $LPIPS_CKPT \
    --output_dir $OUTPUT_DIR \
    --image_anno $IMAGE_ANNO \
    --video_anno $VIDEO_ANNO \
    --use_image_video_mixed_training \
    --image_mix_ratio 0.1 \
    --resolution $RESOLUTION \
    --max_frames $NUM_FRAMES \
    --disc_start 0 \
    --kl_weight 1e-12 \
    --pixelloss_weight 10.0 \
    --perceptual_weight 1.0 \
    --disc_weight 0.5 \
    --batch_size $BATCH_SIZE \
    --opt adamw \
    --opt_betas 0.9 0.95 \
    --seed 42 \
    --weight_decay 1e-3 \
    --clip_grad 1.0 \
    --lr 1e-4 \
    --lr_disc 1e-4 \
    --warmup_epochs 0 \
    --epochs 10 \
    --iters_per_epoch 1000 \
    --print_freq 40 \
    --save_ckpt_freq 1 \
    --add_discriminator

# Stage-2
CONTEXT_SIZE=1  # Context parallel size; must satisfy GPUS % CONTEXT_SIZE == 0
NUM_FRAMES=18   # 17 * CONTEXT_SIZE + 1
VAE_CKPT_PATH=./temp_vae/checkpoint.pth   # The Stage-1 trained checkpoint
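
# Optional sanity checks (a minimal sketch, not part of the training code):
# the context-parallel launch assumes GPUS % CONTEXT_SIZE == 0, and Stage-2
# resumes from the Stage-1 checkpoint, which must exist.
if (( GPUS % CONTEXT_SIZE != 0 )); then
    echo "GPUS ($GPUS) must be divisible by CONTEXT_SIZE ($CONTEXT_SIZE)" >&2
    exit 1
fi
if [ ! -f "$VAE_CKPT_PATH" ]; then
    echo "Stage-1 checkpoint not found: $VAE_CKPT_PATH" >&2
    exit 1
fi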

torchrun --nproc_per_node $GPUS \
    train/train_video_vae.py \
    --num_workers 6 \
    --model_path $VAE_MODEL_PATH \
    --model_dtype bf16 \
    --pretrained_vae_weight $VAE_CKPT_PATH \
    --use_context_parallel \
    --context_size $CONTEXT_SIZE \
    --lpips_ckpt $LPIPS_CKPT \
    --output_dir $OUTPUT_DIR \
    --video_anno $VIDEO_ANNO \
    --image_mix_ratio 0.0 \
    --resolution $RESOLUTION \
    --max_frames $NUM_FRAMES \
    --disc_start 0 \
    --kl_weight 1e-12 \
    --pixelloss_weight 10.0 \
    --perceptual_weight 1.0 \
    --disc_weight 0.5 \
    --batch_size $BATCH_SIZE \
    --opt adamw \
    --opt_betas 0.9 0.95 \
    --seed 42 \
    --weight_decay 1e-3 \
    --clip_grad 1.0 \
    --lr 1e-4 \
    --lr_disc 1e-4 \
    --warmup_epochs 1 \
    --epochs 10 \
    --iters_per_epoch 1000 \
    --print_freq 40 \
    --save_ckpt_freq 1 \
    --add_discriminator