RIDCP.yml 2.8 KB
Newer Older
mashun1's avatar
ridcp  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# General settings: experiment name, model class, scale, device count, seed.
name: vq_weight_dehaze_trained_on_ours
model_type: VQDehazeModel
# Anchored as `upscale` so that network_g.scale_factor can alias the same value.
scale: &upscale 1
# Set num_gpu to 0 for CPU mode.
num_gpu: 4
manual_seed: 0

# Dataset and data loader settings.
datasets:
  train:
    name: General_Image_Train
    type: HazeOnlineDataset
    # Ground-truth RGB images and their matching depth maps.
    dataroot_gt: datasets/rgb_500
    dataroot_depth: datasets/depth_500
    # Sampling ranges consumed by HazeOnlineDataset.
    # Presumably beta = scattering coefficient and A = atmospheric light of
    # the haze model -- TODO confirm against the dataset implementation.
    beta_range: [0.3, 1.5]
    A_range: [0.25, 1.0]
    color_p: 1.0
    color_range: [-0.025, 0.025]
    io_backend:
      type: disk

    # Crop size and augmentation switches.
    gt_size: 256
    use_resize_crop: true
    use_flip: true
    use_rot: false

    # Data loader.
    use_shuffle: true
    # `bsz` is reused below, so the worker count and the prefetch queue length
    # always equal the per-GPU batch size (4).
    batch_size_per_gpu: &bsz 4
    num_worker_per_gpu: *bsz
    dataset_enlarge_ratio: 1

    prefetch_mode: cpu
    num_prefetch_queue: *bsz

  # NOTE(review): the validation split reuses the training dataset name and
  # the same data roots -- confirm this is intentional.
  val:
    name: General_Image_Train
    type: HazeOnlineDataset
    dataroot_gt: datasets/rgb_500
    dataroot_depth: datasets/depth_500
    beta_range: [0.3, 1.5]
    A_range: [0.25, 1.0]
    color_p: 1.0
    color_range: [-0.025, 0.025]
    io_backend:
      type: disk

# Network structures: generator.
network_g:
  type: VQWeightDehazeNet
  gt_resolution: 256
  norm_type: gn
  act_type: silu
  # Aliases the top-level `scale` value (anchor `upscale`).
  scale_factor: *upscale
  # One entry per codebook; the meaning of the three numbers is defined by
  # the network class -- TODO confirm (not visible in this file).
  codebook_params:
    - [64, 1024, 512]

  LQ_stage: true
  use_weight: false
  weight_alpha: -1.0
  # Presumably modules whose names match these keywords are frozen during
  # training -- verify in the network implementation.
  frozen_module_keywords:
    - quantize
    - decoder_group
    - after_quant_group
    - out_conv

# Network structures: discriminator (UNetDiscriminatorSN with 512 input channels).
network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 512

# Paths to pretrained weights and checkpoint-loading options.
path:
  pretrain_network_hq: pretrained_models/pretrained_HQPs.pth
  # No generator or discriminator checkpoint is configured; both are null.
  pretrain_network_g: null
  pretrain_network_d: null
  strict_load: false
  # resume_state: ~

# Training settings: optimizers, LR schedule, iteration budget, and losses.
train:
  # The `!!float` tags force exponent literals without a decimal point
  # (e.g. 1e-4) to load as floats rather than strings on YAML 1.1 parsers.
  optim_g:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]
  optim_d:
    type: Adam
    lr: !!float 4e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  # NOTE(review): with gamma set to 1 the decay factor is 1, so the milestones
  # should leave the learning rate unchanged; the milestones above total_iter
  # (45000) would never be reached anyway -- confirm this is intended.
  scheduler:
    type: MultiStepLR
    milestones: [5000, 10000, 15000, 20000, 250000, 300000, 350000]
    gamma: 1

  total_iter: 45000
  # -1 disables warm-up.
  warmup_iter: -1

  # Losses.
  pixel_opt:
    type: L1Loss
    loss_weight: 1.0
    reduction: mean

  perceptual_opt:
    type: LPIPSLoss
    loss_weight: 1.0

  gan_opt:
    type: GANLoss
    gan_type: hinge
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: 0.1

  codebook_opt:
    loss_weight: 1.0

  semantic_opt:
    loss_weight: 0.1

  # Discriminator update cadence.
  net_d_iters: 1
  net_d_init_iters: !!float 0

# Validation settings.
val:
  # NOTE(review): val_freq (80000) is larger than train.total_iter (45000);
  # confirm whether validation is meant to run during training at all.
  val_freq: !!float 80000
  save_img: true

  key_metric: psnr
  metrics:
    # The mapping keys below are metric names and can be arbitrary;
    # `type` selects the metric implementation.
    psnr:
      type: psnr
      crop_border: 4
      test_y_channel: true
    ssim:
      type: ssim
      crop_border: 4
      test_y_channel: true
    lpips:
      type: lpips
      better: lower

# Logging settings (frequencies are iteration counts written as tagged floats).
logger:
  print_freq: 10
  # 1e3 = 1000, 5e2 = 500, 1e2 = 100.
  save_checkpoint_freq: !!float 1e3
  save_latest_freq: !!float 5e2
  show_tf_imgs_freq: !!float 1e2
  use_tb_logger: true