# ft.yml
name: ft

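# Reproducibility: the base seed is offset per device and per dataloader
# worker (presumably rank/worker id added to `seed`) so parallel processes
# draw independent random streams.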
seed: 2233
device_specific_seed: true
worker_specific_seed: true

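# Data pipeline: training mixture definition, chat-style prompt formatting,
# and token/resolution budgets.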
data:
  data_path: data_configs/train/example/mix.yml
  use_chat_template: true
  maximum_text_tokens: 888
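  # Conditioning dropout, presumably for classifier-free guidance: prompts
  # and reference images are randomly dropped at these rates during training.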
  prompt_dropout_prob: !!float 0.0001
  ref_img_dropout_prob: !!float 0.5
  max_output_pixels: 1048576 # 1024 * 1024
  max_input_pixels: [1048576, 1048576, 589824, 262144] # [1024 * 1024, 1024 * 1024, 768 * 768, 512 * 512]
  max_side_length: 2048
  
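# Pretrained components: the FLUX.1-dev VAE, the Qwen2.5-VL-3B text encoder,
# and the OmniGen2 transformer weights.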
model:
  pretrained_vae_model_name_or_path: black-forest-labs/FLUX.1-dev
  pretrained_text_encoder_model_name_or_path: Qwen/Qwen2.5-VL-3B-Instruct
  pretrained_model_path: OmniGen2/OmniGen2
  
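  # Transformer architecture. The numbers are internally consistent:
  # 2520 hidden / 21 heads = 120 per-head dim, which matches
  # sum(axes_dim_rope) = 40 + 40 + 40 for the 3-axis RoPE, and
  # 21 query heads over 7 KV heads gives grouped-query attention (3:1).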
  arch_opt:
    patch_size: 2
    in_channels: 16
    hidden_size: 2520
    num_layers: 32
    num_refiner_layers: 2
    num_attention_heads: 21
    num_kv_heads: 7
    multiple_of: 256
    norm_eps: !!float 1e-05
    axes_dim_rope: [40, 40, 40]
    axes_lens: [10000, 10000, 10000]
    text_feat_dim: 2048
    timestep_scale: !!float 1000

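# Flow-matching transport: 'lognorm' draws timesteps from a (logit-)normal
# SNR distribution; do_shift / dynamic_time_shift enable resolution-dependent
# timestep shifting.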
transport:
  snr_type: lognorm
  do_shift: true
  dynamic_time_shift: true

train:
  # Dataloader
  global_batch_size: 16
  batch_size: 2
  gradient_accumulation_steps: 1
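  # 16 / (2 * 1) implies 8 data-parallel processes, assuming
  # global_batch_size counts samples across all ranks per optimizer step.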

  max_train_steps: 4000
  
  dataloader_num_workers: 6

  # Optimizer
  learning_rate: !!float 8e-7
  scale_lr: false
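  # timm constant-with-warmup schedule: 500 warmup steps from an effectively
  # zero initial LR; t_in_epochs: false means warmup_t is counted in steps.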
  lr_scheduler: timm_constant_with_warmup
  warmup_t: 500
  warmup_lr_init: !!float 1e-18
  warmup_prefix: true
  t_in_epochs: false

  # resume_from_checkpoint: 

  use_8bit_adam: false
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_weight_decay: !!float 0.01
  adam_epsilon: !!float 1e-08
  max_grad_norm: 1

  gradient_checkpointing: true
  
  set_grads_to_none: true

  # Misc
  allow_tf32: false
  mixed_precision: 'bf16'

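  # A decay of 0.0 effectively disables the EMA of model weights.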
  ema_decay: 0.0

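# Validation and visualization cadence, presumably in optimizer steps.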
val:
  validation_steps: 500
  train_visualization_steps: 100

logger:
  log_with: [wandb, tensorboard]
  # log_with: ~

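  # Save every 1000 steps; '~' (null) keeps all checkpoints (no pruning).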
  checkpointing_steps: 1000
  checkpoints_total_limit: ~

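# 'latest' resumes from the most recent checkpoint in the output directory,
# if one exists; cache_dir left empty presumably falls back to the default cache.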
cache_dir: 
resume_from_checkpoint: latest