// train.proto
syntax = "proto2";

package object_detection.protos;

import "object_detection/protos/optimizer.proto";
import "object_detection/protos/preprocessor.proto";

// Checkpoint restoration mode used for `fine_tune_checkpoint`
// (selected by TrainConfig.fine_tune_checkpoint_version, which defaults to V1).
// NOTE(review): the value names are not prefixed with the enum name, so they
// live directly in the `object_detection.protos` package scope and could
// collide with other unprefixed enum values there. They are intentionally left
// unchanged because renaming them would break existing text-format configs.
enum CheckpointVersion {
  // Zero value; not selected by default (the training config defaults to V1).
  UNKNOWN = 0;
  // Restore using the TensorFlow v1 style of restoring checkpoints.
  V1 = 1;
  // Restore using the eager mode checkpoint restoration API.
  V2 = 2;
}

// Message for configuring DetectionModel training jobs (train.py).
// Next id: 31
// NOTE(review): field number 29 is not used by any field in this message. If
// it belonged to a removed field, add `reserved 29;` so it is never reused.
message TrainConfig {
  // Effective batch size to use for training.
  // For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be
  // `batch_size` / number of cores (or `batch_size` / number of GPUs).
  optional uint32 batch_size = 1 [default=32];

  // Data augmentation options.
  repeated PreprocessingStep data_augmentation_options = 2;

  // Whether to synchronize replicas during training.
  optional bool sync_replicas = 3 [default=false];

  // How frequently to keep checkpoints, in hours.
  optional float keep_checkpoint_every_n_hours = 4 [default=10000.0];

  // Optimizer used to train the DetectionModel.
  optional Optimizer optimizer = 5;

  // If greater than 0, clips gradients by this value.
  optional float gradient_clipping_by_norm = 6 [default=0.0];

  // Checkpoint to restore variables from. Typically used to load feature
  // extractor variables trained outside of object detection.
  optional string fine_tune_checkpoint = 7 [default=""];

  // This option controls how variables are restored from the (pre-trained)
  // fine_tune_checkpoint. For TF2 models, 3 different types are supported:
  //   1. "classification": Restores only the classification backbone part of
  //        the feature extractor. This option is typically used when you want
  //        to train a detection model starting from a pre-trained image
  //        classification model, e.g. a ResNet model pre-trained on ImageNet.
  //   2. "detection": Restores the entire feature extractor. The only parts
  //        of the full detection model that are not restored are the box and
  //        class prediction heads. This option is typically used when you want
  //        to use a pre-trained detection model and train on a new dataset or
  //        task which requires different box and class prediction heads.
  //   3. "full": Restores the entire detection model, including the
  //        feature extractor, its classification backbone, and the prediction
  //        heads. This option should only be used when the pre-training and
  //        fine-tuning tasks are the same. Otherwise, the model's parameters
  //        may have incompatible shapes, which will cause errors when
  //        attempting to restore the checkpoint.
  // For more details about this parameter, see the restore_map (TF1) or
  // restore_from_object (TF2) function documentation in the
  // /meta_architectures/*meta_arch.py files.
  optional string fine_tune_checkpoint_type = 22 [default=""];

  // Either V1 or V2. If V1, restores the checkpoint using the TensorFlow
  // v1 style of restoring checkpoints. If V2, uses the eager mode checkpoint
  // restoration API.
  optional CheckpointVersion fine_tune_checkpoint_version = 28 [default=V1];

  // [Deprecated]: use fine_tune_checkpoint_type instead.
  // Specifies if the finetune checkpoint is from an object detection model.
  // If from an object detection model, the model being trained should have
  // the same parameters with the exception of the num_classes parameter.
  // If false, it assumes the checkpoint was an object classification model.
  optional bool from_detection_checkpoint = 8 [default=false, deprecated=true];

  // Whether to load all checkpoint vars that match model variable names and
  // sizes. This option is only available if `from_detection_checkpoint` is
  // True. This option is *not* supported for TF2 --- setting it to true
  // will raise an error. Instead, set fine_tune_checkpoint_type: 'full'.
  optional bool load_all_detection_checkpoint_vars = 19 [default=false];

  // Whether to run dummy computation when loading a `fine_tune_checkpoint`.
  // This option is true by default since it is often necessary to run the model
  // on a dummy input before loading a `fine_tune_checkpoint`, in order to
  // ensure that all the model variables have already been built successfully.
  // Some meta architectures, like CenterNet, do not require dummy computation
  // to successfully load all checkpoint variables, and in these cases this
  // flag may be set to false to reduce startup time and memory consumption.
  // Note, this flag only affects dummy computation when loading a
  // `fine_tune_checkpoint`, e.g. it does not affect the dummy computation that
  // is run when creating shadow copies of model variables when using EMA.
  optional bool run_fine_tune_checkpoint_dummy_computation = 30 [default=true];

  // Number of steps to train the DetectionModel for. If 0, will train the model
  // indefinitely.
  optional uint32 num_steps = 9 [default=0];

  // Number of training steps between replica startup.
  // This flag must be set to 0 if sync_replicas is set to true.
  // NOTE(review): the default (15) is non-zero, so configs that enable
  // sync_replicas must override this explicitly. The field is a float even
  // though it counts steps; changing the type would break the wire format.
  optional float startup_delay_steps = 10 [default=15];

  // If greater than 0, multiplies the gradient of bias variables by this
  // amount.
  optional float bias_grad_multiplier = 11 [default=0];

  // Variables that should be updated during training. Note that variables which
  // also match the patterns in freeze_variables will be excluded.
  repeated string update_trainable_variables = 25;

  // Variables that should not be updated during training. If
  // update_trainable_variables is not empty, only eliminates the included
  // variables according to freeze_variables patterns.
  repeated string freeze_variables = 12;

  // Number of replicas to aggregate before making parameter updates.
  optional int32 replicas_to_aggregate = 13 [default=1];

  // [Deprecated]. Maximum number of elements to store within a queue.
  optional int32 batch_queue_capacity = 14 [default=150, deprecated=true];

  // [Deprecated]. Number of threads to use for batching.
  optional int32 num_batch_queue_threads = 15 [default=8, deprecated=true];

  // [Deprecated]. Maximum capacity of the queue used to prefetch assembled
  // batches.
  optional int32 prefetch_queue_capacity = 16 [default=5, deprecated=true];

  // If true, boxes with the same coordinates will be merged together.
  // This is useful when each box can have multiple labels.
  // Note that only Sigmoid classification losses should be used.
  optional bool merge_multiple_label_boxes = 17 [default=false];

  // If true, will use multiclass scores from object annotations as ground
  // truth. Currently only compatible with annotated image inputs.
  optional bool use_multiclass_scores = 24 [default=false];

  // Whether to add regularization loss to `total_loss`. This is true by
  // default and adds all regularization losses defined in the model to
  // `total_loss`.
  // Setting this option to false is very useful while debugging the model and
  // losses.
  optional bool add_regularization_loss = 18 [default=true];

  // [Deprecated]. Maximum number of boxes used during training.
  // Set this to at least the maximum amount of boxes in the input data.
  // Otherwise, it may cause "Data loss: Attempted to pad to a smaller size
  // than the input element" errors.
  optional int32 max_number_of_boxes = 20 [default=100, deprecated=true];

  // Whether to remove padding along `num_boxes` dimension of the groundtruth
  // tensors.
  optional bool unpad_groundtruth_tensors = 21 [default=true];

  // Whether to retain original images (i.e. not pre-processed) in the tensor
  // dictionary, so that they can be displayed in Tensorboard. Note that this
  // will lead to a larger memory footprint.
  optional bool retain_original_images = 23 [default=false];

  // Whether to use bfloat16 for training. This is currently only supported for
  // TPUs.
  optional bool use_bfloat16 = 26 [default=false];

  // Whether to summarize gradients.
  optional bool summarize_gradients = 27 [default=false];
}