train_causal_video_vae.sh 3.39 KB
Newer Older
mashun's avatar
mashun committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#!/bin/bash
#
# Causal video VAE training launcher. Training runs in two stages:
#   Stage-1: mixed image and video training.
#   Stage-2: pure video training, using context parallelism to load clips
#            with more video frames (up to 257 frames).
#
# Configuration template (the concrete values used by this run are below):
# GPUS=8  # The gpu number
# VAE_MODEL_PATH=PATH/vae_ckpt   # The vae model dir
# LPIPS_CKPT=vgg_lpips.pth    # The LPIPS VGG CKPT path, used for calculating the lpips loss
# OUTPUT_DIR=/PATH/output_dir    # The checkpoint saving dir
# IMAGE_ANNO=annotation/image_text.jsonl   # The image annotation file path
# VIDEO_ANNO=annotation/video_text.jsonl   # The video annotation file path
# RESOLUTION=256     # The training resolution, default is 256
# NUM_FRAMES=17     # x * 8 + 1, the number of video frames
# BATCH_SIZE=2

# Abort on the first failing command (e.g. a crashed Stage-1) instead of
# carrying on into Stage-2; also reject unset variables.
set -euo pipefail

# Devices visible to this run.
export HIP_VISIBLE_DEVICES=4,5,6,7

# One process per visible device: derive the count from the device list so
# the two settings cannot drift apart (4,5,6,7 -> 4).
IFS=, read -r -a _visible_devices <<<"$HIP_VISIBLE_DEVICES"
GPUS=${#_visible_devices[@]}  # The gpu number
VAE_MODEL_PATH=/home/modelzoo/Pyramid-Flow/pyramid_flow_model/pyramid-flow-miniflux/causal_video_vae   # The vae model dir
LPIPS_CKPT=/home/modelzoo/Pyramid-Flow/pyramid_flow_model/pyramid-flow-miniflux/vgg.pth    # The LPIPS VGG CKPT path, used for calculating the lpips loss
OUTPUT_DIR=./temp_vae    # The checkpoint saving dir
IMAGE_ANNO=annotation/customs/vae_image.jsonl   # The image annotation file path
VIDEO_ANNO=annotation/customs/vae_video.jsonl   # The video annotation file path
RESOLUTION=256     # The training resolution, default is 256
NUM_FRAMES=9     # x * 8 + 1, the number of video frames
BATCH_SIZE=2

# When --add_discriminator is used, --disc_start must be set to 0,
# otherwise the training script raises an error.
# Stage-1: mixed image/video training from the base VAE weights.
# All expansions are quoted so paths containing spaces survive intact.
torchrun --nproc_per_node "$GPUS" \
    train/train_video_vae.py \
    --num_workers 6 \
    --model_path "$VAE_MODEL_PATH" \
    --model_dtype bf16 \
    --lpips_ckpt "$LPIPS_CKPT" \
    --output_dir "$OUTPUT_DIR" \
    --image_anno "$IMAGE_ANNO" \
    --video_anno "$VIDEO_ANNO" \
    --use_image_video_mixed_training \
    --image_mix_ratio 0.1 \
    --resolution "$RESOLUTION" \
    --max_frames "$NUM_FRAMES" \
    --disc_start 0 \
    --kl_weight 1e-12 \
    --pixelloss_weight 10.0 \
    --perceptual_weight 1.0 \
    --disc_weight 0.5 \
    --batch_size "$BATCH_SIZE" \
    --opt adamw \
    --opt_betas 0.9 0.95 \
    --seed 42 \
    --weight_decay 1e-3 \
    --clip_grad 1.0 \
    --lr 1e-4 \
    --lr_disc 1e-4 \
    --warmup_epochs 0 \
    --epochs 10 \
    --iters_per_epoch 1000 \
    --print_freq 40 \
    --save_ckpt_freq 1 \
    --add_discriminator

# Stage-2: pure video training with context parallelism, initialised from the
# Stage-1 checkpoint.
CONTEXT_SIZE=1  # context parallel size, GPUS % CONTEXT_SIZE == 0
# Frame count must follow the same "x * 8 + 1" rule as Stage-1's NUM_FRAMES.
# 16 * CONTEXT_SIZE + 1 satisfies it for any CONTEXT_SIZE (17 for size 1).
# The previous "17 * CONTEXT_SIZE + 1" gave 18, which is not 8k + 1.
# NOTE(review): assumes each context-parallel rank wants a multiple of 8
# frames — TODO confirm against the context-parallel loader.
NUM_FRAMES=$((16 * CONTEXT_SIZE + 1))
# The stage-1 trained ckpt, derived from where Stage-1 wrote its output.
VAE_CKPT_PATH="${OUTPUT_DIR}/checkpoint.pth"

# Fail fast on misconfiguration instead of deep inside torchrun.
if (( GPUS % CONTEXT_SIZE != 0 )); then
    echo "error: GPUS ($GPUS) must be divisible by CONTEXT_SIZE ($CONTEXT_SIZE)" >&2
    exit 1
fi
if (( (NUM_FRAMES - 1) % 8 != 0 )); then
    echo "error: NUM_FRAMES ($NUM_FRAMES) must be of the form 8 * x + 1" >&2
    exit 1
fi
if [[ ! -f "$VAE_CKPT_PATH" ]]; then
    echo "error: Stage-1 checkpoint not found: $VAE_CKPT_PATH" >&2
    exit 1
fi

# NOTE(review): Stage-2 writes into the same OUTPUT_DIR as Stage-1, so its
# checkpoints may overwrite the Stage-1 checkpoint.pth loaded above —
# consider a separate directory if Stage-1 weights must be kept.
torchrun --nproc_per_node "$GPUS" \
    train/train_video_vae.py \
    --num_workers 6 \
    --model_path "$VAE_MODEL_PATH" \
    --model_dtype bf16 \
    --pretrained_vae_weight "$VAE_CKPT_PATH" \
    --use_context_parallel \
    --context_size "$CONTEXT_SIZE" \
    --lpips_ckpt "$LPIPS_CKPT" \
    --output_dir "$OUTPUT_DIR" \
    --video_anno "$VIDEO_ANNO" \
    --image_mix_ratio 0.0 \
    --resolution "$RESOLUTION" \
    --max_frames "$NUM_FRAMES" \
    --disc_start 0 \
    --kl_weight 1e-12 \
    --pixelloss_weight 10.0 \
    --perceptual_weight 1.0 \
    --disc_weight 0.5 \
    --batch_size "$BATCH_SIZE" \
    --opt adamw \
    --opt_betas 0.9 0.95 \
    --seed 42 \
    --weight_decay 1e-3 \
    --clip_grad 1.0 \
    --lr 1e-4 \
    --lr_disc 1e-4 \
    --warmup_epochs 1 \
    --epochs 10 \
    --iters_per_epoch 1000 \
    --print_freq 40 \
    --save_ckpt_freq 1 \
    --add_discriminator