"official/legacy/bert/model_training_utils.py" did not exist on "ebc28058dc268be3d4b9704a67899cbc769cb69a"
# train_hifacegan.yml
# general settings
# Top-level run options. The run name mentions SR4x, yet `scale` is 1:
# the low-quality (lq) input is not resized by HiFaceGAN.
name: HiFaceGAN_SR4x_train
model_type: HiFaceGANModel
scale: 1    # HiFaceGAN does not resize lq input
num_gpu: 1  # set num_gpu: 0 for cpu mode
# Fixed seed value; presumably used for reproducible runs -- confirm against
# the training entry point.
manual_seed: 0

datasets:
  # Training split: paired ground-truth (gt) and low-quality (lq) folders.
  train:
    name: FFHQ_sr4x
    type: PairedImageDataset
    dataroot_gt: datasets/FFHQ_512_gt
    dataroot_lq: datasets/FFHQ_512_lq_sr4x
    # lmdb alternative for the two roots above:
    # dataroot_gt: datasets/FFHQ_512_gt.lmdb
    # dataroot_lq: datasets/FFHQ_512_lq_sr4x.lmdb
    io_backend:
      type: disk
      # lmdb alternative:
      # type: lmdb

    gt_size: 512
    use_hflip: true
    use_rot: false

    # data loader
    num_worker_per_gpu: 1
    batch_size_per_gpu: 1
    dataset_enlarge_ratio: 1
    prefetch_mode: null

  # Validation split. For now it reuses the training dataset folders.
  val:
    name: FFHQ_sr4x_val
    type: PairedImageDataset
    dataroot_gt: datasets/FFHQ_512_gt
    dataroot_lq: datasets/FFHQ_512_lq_sr4x
    io_backend:
      type: disk

# network structures
# Generator network options.
network_g:
  type: HiFaceGAN
  num_in_ch: 3
  num_feat: 48
  use_vae: false
  z_dim: 256  # dummy var
  crop_size: 512
  # Alternative normalization choices, kept for reference:
  # norm_g: spectralspadesyncbatch3x3
  # norm_g: spectralspadeinstance3x3
  # 20210519: Use batchnorm for now.
  norm_g: spectralspadebatch3x3
  # HiFaceGAN supports progressive training, so the network architecture
  # depends on this flag.
  # NOTE(review): `false` inside a training config looks surprising --
  # confirm it is intended.
  is_train: false

# Discriminator network options.
network_d:
  type: HiFaceGANDiscriminator
  num_in_ch: 3
  num_out_ch: 3
  conditional_d: true
  num_feat: 64
  norm_d: spectralinstance


# training settings
train:
  # Generator and discriminator optimizers share the same Adam settings.
  # `!!float` is kept on 2e-4: the literal has no decimal point, so YAML 1.1
  # loaders would otherwise read it as a string.
  optim_g:
    type: Adam
    lr: !!float 2e-4
    weight_decay: 0
    betas: [0.9, 0.999]
  optim_d:
    type: Adam
    lr: !!float 2e-4
    weight_decay: 0
    betas: [0.9, 0.999]

  scheduler:
    # TODO: Integrate the exact scheduling system of HiFaceGAN, which uses a
    # fixed lr followed by a linear decay. That schedule is not supported in
    # the current BasicSR project.
    type: MultiStepLR
    milestones: [500, 1000, 2000, 3000]
    gamma: 0.5

  # By default HiFaceGAN trains for 50 epochs; convert to iterations with
  #    total_iter = #(epochs) * #(dataset_size) * enlarge_ratio / batch_size
  total_iter: 5000
  warmup_iter: -1  # no warm up

  # losses
  #   Note: HiFaceGAN does not need pixel loss, use it at your own discretion
  # pixel_opt:
  #   type: L1Loss
  #   loss_weight: !!float 1e-2
  #   reduction: mean
  perceptual_opt:
    type: PerceptualLoss
    # vgg_layer_indices: 2,7,12,21,30
    # weights: 1/32, 1/16, 1/8, 1/4, 1
    layer_weights:
      relu1_1: 3.125e-2
      relu2_1: 6.25e-2
      relu3_1: 0.125
      relu4_1: 0.25
      relu5_1: 1.0
    vgg_type: vgg19
    use_input_norm: false  # keep in [0,1] range
    range_norm: false
    perceptual_weight: 10.0
    style_weight: 0
    criterion: l1
  gan_opt:
    type: MultiScaleGANLoss
    gan_type: lsgan
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: 1.0
  feature_matching_opt:
    type: GANFeatLoss
    loss_weight: 10.0
    criterion: l1

  net_d_iters: 1
  net_d_init_iters: 0

# path
path:
  # Generator checkpoint to start from; left unset (null) here.
  # Example value: experiments/pretrained_models/4xsr/latest_net_G.pth
  pretrain_network_g: null
  strict_load_g: true

# validation settings
val:
  # 5e3 equals train.total_iter (5000); presumably validation therefore runs
  # only at the final iteration -- confirm against the training loop.
  # The !!float tag is required: without a decimal point, YAML 1.1 loaders
  # would read 5e3 as a string.
  val_freq: !!float 5e3
  save_img: true

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      # NOTE(review): presumably the number of border pixels excluded before
      # computing the metric -- confirm against calculate_psnr.
      crop_border: 4
      test_y_channel: false

# logging settings
logger:
  print_freq: 100
  # Same value as val.val_freq and train.total_iter (5000). The explicit
  # !!float tag is kept because 5e3 has no decimal point.
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  # Both wandb fields are left unset (null).
  wandb:
    project: null
    resume_id: null