# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factory to provide model configs."""

from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import maskrcnn_config
from official.vision.detection.configs import olnmask_config
from official.vision.detection.configs import retinanet_config
from official.vision.detection.configs import shapemask_config


def config_generator(model):
  """Builds the default hyperparameter config for a detection model.

  Args:
    model: name of the model. Supported values are 'retinanet', 'mask_rcnn',
      'olnmask' and 'shapemask'.

  Returns:
    A `params_dict.ParamsDict` constructed from the model's default config
    together with its restrictions.

  Raises:
    ValueError: if `model` is not one of the supported names.
  """
  if model == 'retinanet':
    default_config = retinanet_config.RETINANET_CFG
    restrictions = retinanet_config.RETINANET_RESTRICTIONS
  elif model == 'mask_rcnn':
    default_config = maskrcnn_config.MASKRCNN_CFG
    restrictions = maskrcnn_config.MASKRCNN_RESTRICTIONS
  elif model == 'olnmask':
    default_config = olnmask_config.OLNMASK_CFG
    restrictions = olnmask_config.OLNMASK_RESTRICTIONS
  elif model == 'shapemask':
    default_config = shapemask_config.SHAPEMASK_CFG
    restrictions = shapemask_config.SHAPEMASK_RESTRICTIONS
  else:
    # Wrap `model` in a 1-tuple so %-formatting never unpacks it; a
    # tuple-valued `model` would otherwise raise TypeError instead of the
    # documented ValueError.
    raise ValueError('Model %s is not supported.' % (model,))

  return params_dict.ParamsDict(default_config, restrictions)