# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""TensorFlow Model Garden Vision training driver."""

from absl import app
from absl import flags
import gin

from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.vision.beta.projects.yolo.common import registry_imports  # pylint: disable=unused-import

FLAGS = flags.FLAGS

# Example invocations:
#
#   python3 -m official.vision.beta.projects.yolo.train \
#       --mode=train_and_eval --experiment=darknet_classification \
#       --model_dir=training_dir \
#       --config_file=official/vision/beta/projects/yolo/configs/experiments/darknet53_tfds.yaml
#
#   python3.8 -m official.vision.beta.projects.yolo.train \
#       --experiment=yolo_darknet --mode train_and_eval \
#       --config_file yolo/configs/experiments/yolov4/inference/512-swin.yaml \
#       --model_dir ../checkpoints/test-swin
#
# NOTE(review): the second example's --config_file path is not prefixed with
# official/vision/beta/projects/ like the first one, so the two examples assume
# different working directories; adjust the path to where you run from.


def main(_):
  """Parses the configuration, builds the task and runs the experiment.

  Applies gin bindings, resolves the experiment params from FLAGS, optionally
  serializes them to `model_dir`, sets the mixed-precision policy, creates the
  distribution strategy and task, runs `FLAGS.mode`, and finally saves the
  operative gin config.

  Args:
    _: Positional arguments forwarded by `app.run`; unused.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)

  model_dir = FLAGS.model_dir
  # Substring test: true for every mode whose name contains 'train'.
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case of
  # GPUs, and bfloat16 in the case of TPUs. Only the dtype is configured here;
  # no loss-scale setting is passed to this call.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)

  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  # Construct the task inside the strategy's scope (presumably so any state it
  # creates is placed according to the strategy).
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)


if __name__ == '__main__':
  # Flags have to be defined before app.run, which parses the command line
  # and then invokes `main`.
  tfm_flags.define_flags()
  app.run(main=main)