# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Registry utility."""
from absl import logging


def register(registered_collection, reg_key):
  """Register decorated function or class to collection.

  Register decorated function or class into registered_collection, in a
  hierarchical order. For example, when reg_key="my_model/my_exp/my_config_0"
  the decorated function or class is stored under
  registered_collection["my_model"]["my_exp"]["my_config_0"].
  This decorator is supposed to be used together with the lookup() function in
  this file.

  Intermediate levels of the hierarchy are created on demand as plain dicts.
  A non-string reg_key is used as-is as a single, flat key.

  A duplicate registration normally raises KeyError. As a temporary exception,
  when the object being registered comes from a module whose name contains
  "beta", the duplicate is only logged: the existing registration is kept and
  the decorated object is returned unchanged.

  Args:
    registered_collection: a dictionary. The decorated function or class will be
      put into this collection.
    reg_key: The key for retrieving the registered function or class. If reg_key
      is a string, it can be hierarchical like my_model/my_exp/my_config_0
  Returns:
    A decorator function. The decorator always returns the object it was
    applied to, so the decorated name stays bound to the function or class.
  Raises:
    KeyError: when function or class to register already exists, or when an
      intermediate element of a hierarchical key is already registered as a
      function or class rather than as a sub-collection.
  """
  def decorator(fn_or_cls):
    """Put fn_or_cls in the dictionary."""
    if isinstance(reg_key, str):
      hierarchy = reg_key.split("/")
      collection = registered_collection
      # Walk (and create, if needed) every level except the last one, which
      # is the leaf key the object itself is stored under.
      for h_idx, entry_name in enumerate(hierarchy[:-1]):
        if entry_name not in collection:
          collection[entry_name] = {}
        collection = collection[entry_name]
        if not isinstance(collection, dict):
          raise KeyError(
              "Collection path {} at position {} already registered as "
              "a function or class.".format(entry_name, h_idx))
      leaf_reg_key = hierarchy[-1]
    else:
      collection = registered_collection
      leaf_reg_key = reg_key

    if leaf_reg_key in collection:
      # `__module__` may be missing or None on some callables; treat that as
      # "not a beta module" instead of failing with a TypeError.
      module_name = getattr(fn_or_cls, "__module__", None) or ""
      if "beta" in module_name:
        # TODO(yeqing): Clean this temporary branch for beta.
        # "new" is the module of the object being registered now, "old" is
        # the entry that is already stored under this key.
        logging.warning(
            "Duplicate registeration of beta module "
            "name %r new %r old %r", reg_key, module_name,
            collection[leaf_reg_key])
        # Keep the existing registration, but hand the object back so the
        # decorated name is not rebound to None.
        return fn_or_cls
      raise KeyError("Function or class {} registered multiple times.".format(
          leaf_reg_key))

    collection[leaf_reg_key] = fn_or_cls
    return fn_or_cls
  return decorator


def lookup(registered_collection, reg_key):
  """Lookup and return decorated function or class in the collection.

  Lookup decorated function or class in registered_collection, in a
  hierarchical order. For example, when
  reg_key="my_model/my_exp/my_config_0",
  this function will return
  registered_collection["my_model"]["my_exp"]["my_config_0"].

  A non-string reg_key is looked up directly as a single, flat key. If a
  string key names an intermediate level rather than a leaf, the
  sub-collection stored at that level is returned.

  Args:
    registered_collection: a dictionary. The decorated function or class will be
      retrieved from this collection.
    reg_key: The key for retrieving the registered function or class. If reg_key
      is a string, it can be hierarchical like my_model/my_exp/my_config_0
  Returns:
    The registered function or class.
  Raises:
    LookupError: when reg_key cannot be found, including when the path
      continues past an entry that is a registered function or class.
  """
  if isinstance(reg_key, str):
    hierarchy = reg_key.split("/")
    collection = registered_collection
    for h_idx, entry_name in enumerate(hierarchy):
      try:
        found = entry_name in collection
      except TypeError:
        # The previous path element resolved to a registered leaf that does
        # not support membership tests, so nothing can live below it.
        found = False
      if not found:
        raise LookupError(
            f"collection path {entry_name} at position {h_idx} is never "
            f"registered. Please make sure the {entry_name} and its library is "
            "imported and linked to the trainer binary.")
      collection = collection[entry_name]
    return collection
  else:
    if reg_key not in registered_collection:
      raise LookupError(
          f"registration key {reg_key} is never "
          f"registered. Please make sure the {reg_key} and its library is "
          "imported and linked to the trainer binary.")
    return registered_collection[reg_key]