# det_r18_vd_ct.yml (last committed by wangsen)
Global:
  use_gpu: true
  epoch_num: 600
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/det_ct/
  save_epoch_step: 10
  # evaluation is run every 1000 iterations (second element below); the
  # first element is presumably the iteration at which it starts — confirm
  eval_batch_step: [0, 1000]
  cal_metric_during_train: false
  pretrained_model: ./pretrain_models/ResNet18_vd_pretrained.pdparams
  # no checkpoint to resume from / no inference export dir configured
  checkpoints: null
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_en/img623.jpg
  save_res_path: ./output/det_ct/predicts_ct.txt

Architecture:
  model_type: det
  algorithm: CT
  # no Transform stage configured
  Transform: null
  # feature extractor: 18-layer ResNet_vd
  Backbone:
    name: ResNet_vd
    layers: 18
  # feature-pyramid neck for CT
  Neck:
    name: CTFPN
  # NOTE(review): in_channels presumably must equal the CTFPN output width;
  # confirm against the neck implementation before changing either side
  Head:
    name: CT_Head
    in_channels: 512
    hidden_dim: 128
    num_classes: 3

Loss:
  # loss implementation selected by name
  name: CTLoss

Optimizer:
  name: Adam
  # learning-rate schedule "Linear" (annotated upstream as PolynomialDecay):
  # presumably decays from learning_rate to end_lr with exponent `power`
  lr:
    name: Linear
    learning_rate: 0.001
    end_lr: 0.0
    epochs: 600  # keep in sync with Global.epoch_num
    # NOTE(review): iterations per epoch for this dataset/batch size; the
    # trainer may override it — verify before relying on this value
    step_each_epoch: 1254
    power: 0.9

PostProcess:
  # CT post-processing; detected boxes are emitted as polygons
  name: CTPostProcess
  box_type: poly

Metric:
  name: CTMetric
  # metric key reported as the main indicator (presumably used to pick the
  # best model during evaluation — confirm)
  main_indicator: f_score

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/total_text/train
    label_file_list:
      - ./train_data/total_text/train/train.txt
    ratio_list: [1.0]
    # transform ops applied to each sample, in order; a bare `- Op:` entry
    # has a null value, i.e. the op is given no explicit parameters
    transforms:
      - DecodeImage:
          img_mode: RGB
          channel_first: false
      - CTLabelEncode:  # label-encoding op for CT (no explicit parameters)
      - RandomScale:
      - MakeShrink:
      - GroupRandomHorizontalFlip:
      - GroupRandomRotate:
      - GroupRandomCropPadding:
      - MakeCentripetalShift:
      - ColorJitter:
          brightness: 0.125
          saturation: 0.5
      - ToCHWImage:
      # NOTE(review): runs after ToCHWImage with no explicit `order`;
      # confirm NormalizeImage's default order matches a CHW image
      - NormalizeImage:
      - KeepKeys:
          # keys kept, in this order, as the items of each dataloader batch
          keep_keys:
            - image
            - gt_kernel
            - training_mask
            - gt_instance
            - gt_kernel_instance
            - training_mask_distance
            - gt_distance
  loader:
    shuffle: true
    drop_last: true
    batch_size_per_card: 4
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/total_text/test
    label_file_list:
      - ./train_data/total_text/test/test.txt
    ratio_list: [1.0]
    # transform ops applied to each sample, in order; a bare `- Op:` entry
    # has a null value, i.e. the op is given no explicit parameters
    transforms:
      - DecodeImage:
          img_mode: RGB
          channel_first: false
      - CTLabelEncode:  # label-encoding op for CT (no explicit parameters)
      - ScaleAlignedShort:
      # normalize while the image is still HWC, then convert to CHW
      - NormalizeImage:
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # keys kept, in this order, as the items of each dataloader batch
          keep_keys:
            - image
            - shape
            - polys
            - texts
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 1
    num_workers: 2