guidelines.md 2.38 KB
Hongkun Yu committed
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Using flags in official models

1. **All common flags must be incorporated in the models.**

   Common flags (e.g. `batch_size`, `model_dir`) are provided by various flag definition functions
   and channeled through `official.utils.flags.core`. For instance, to define common supervised
   learning parameters, one could use the following code:

   ```python
   from absl import app as absl_app
   from absl import flags

   from official.utils.flags import core as flags_core


   def define_flags():
     flags_core.define_base()
     # absl exposes this as adopt_module_key_flags.
     flags.adopt_module_key_flags(flags_core)


   def main(_):
     flags_obj = flags.FLAGS
     print(flags_obj)


   if __name__ == "__main__":
     define_flags()  # Flags must be defined before absl parses them.
     absl_app.run(main)
   ```
2. **Validate flag values.**

   See the [Validators](#validators) section for implementation details.

   Validators in the official model repo should not access the file system (for example, to
   verify that files exist), due to the strict ordering requirements.

3. **Flag values should not be mutated.**

   Instead of mutating flag values, use getter functions to return the desired values. An example
   getter function is the `get_tf_dtype` function below:

   ```python
   import tensorflow as tf
   from absl import flags

   # Map string to TensorFlow dtype
   DTYPE_MAP = {
       "fp16": tf.float16,
       "fp32": tf.float32,
   }


   def get_tf_dtype(flags_obj):
     if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
       # If the graph_rewrite is used, we build the graph with fp32, and let the
       # graph rewrite change ops to fp16.
       return tf.float32
     return DTYPE_MAP[flags_obj.dtype]


   def main(_):
     # flags.FLAGS is already the parsed flag container; do not call it.
     flags_obj = flags.FLAGS

     # Do not mutate flags_obj
     # if flags_obj.fp16_implementation == "graph_rewrite":
     #   flags_obj.dtype = "fp32"  # Don't do this

     print(get_tf_dtype(flags_obj))
     ...
   ```
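
As a rough illustration of guideline 2, the sketch below attaches validators that inspect only the flag values themselves and never touch the file system. The flag names, the allowed values, and the `define_example_flags` helper are hypothetical and are not part of `official.utils.flags.core`; `flags.register_validator` and the `flag_values` argument come from absl.

```python
from absl import flags

# Hypothetical set of accepted values, chosen only for illustration.
ALLOWED_DTYPES = ("fp16", "fp32")


def define_example_flags(flag_values=flags.FLAGS):
  """Defines two illustrative flags and registers a validator for each."""
  flags.DEFINE_integer("batch_size", 32, "Global batch size.",
                       flag_values=flag_values)
  flags.DEFINE_string("dtype", "fp32", "Numeric precision: fp16 or fp32.",
                      flag_values=flag_values)

  # Each checker looks only at the flag value; no file system access.
  flags.register_validator(
      "batch_size",
      lambda value: value > 0,
      message="--batch_size must be a positive integer.",
      flag_values=flag_values)
  flags.register_validator(
      "dtype",
      lambda value: value in ALLOWED_DTYPES,
      message="--dtype must be one of: " + ", ".join(ALLOWED_DTYPES),
      flag_values=flag_values)
```

Passing a fresh `flags.FlagValues()` instance keeps the sketch isolated from the global `flags.FLAGS`. Parsing an argument list that contains `--batch_size=0` with that instance raises `flags.IllegalFlagValueError`.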