# Using flags in official models

1. **All common flags must be incorporated in the models.**

   Common flags (e.g. `batch_size` and `model_dir`) are provided by various flag definition functions
   and channeled through `official.utils.flags.core`. For instance, to define common supervised
   learning parameters, one could use the following code:

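   ```python
   from absl import app as absl_app
   from absl import flags

   from official.utils.flags import core as flags_core


   def define_flags():
     # Define the common base flags and make them key flags of this module
     # so that they are listed in this program's --help output.
     flags_core.define_base()
     flags.adopt_key_flags(flags_core)


   def main(_):
     flags_obj = flags.FLAGS
     print(flags_obj)


   if __name__ == "__main__":
     # Flags must be defined before absl_app.run() parses the command line.
     define_flags()
     absl_app.run(main)
   ```

2. **Validate flag values.**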

   See the [Validators](#validators) section for implementation details.

   Validators in the official model repo should not access the file system (for example, to
   verify that a file exists) because of the repo's strict ordering requirements.

3. **Flag values should not be mutated.**

   Instead of mutating flag values, use getter functions to return the desired values. An example
   getter is the `get_tf_dtype` function below:

   ```python
   import tensorflow as tf
   from absl import flags

   # Map string to TensorFlow dtype
   DTYPE_MAP = {
       "fp16": tf.float16,
       "fp32": tf.float32,
   }


   def get_tf_dtype(flags_obj):
     if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
       # If the graph_rewrite is used, we build the graph with fp32, and let the
       # graph rewrite change ops to fp16.
       return tf.float32
     return DTYPE_MAP[flags_obj.dtype]


   def main(_):
     flags_obj = flags.FLAGS

     # Do not mutate flags_obj, e.g.:
     # if flags_obj.fp16_implementation == "graph_rewrite":
     #   flags_obj.dtype = "fp32"  # Don't do this

     print(get_tf_dtype(flags_obj))
     ...
   ```
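To illustrate the validation guideline in item 2, here is a minimal sketch assuming absl's `flags.register_validator` API. The checker inspects only the flag's value and never touches the file system. The `batch_size` flag, the private `FlagValues` instance, and the `make_flag_values`/`parse` helpers are hypothetical and exist only for this example; in the official models, common flags come from the `flags_core` definition functions shown in item 1.

```python
from absl import flags


def make_flag_values():
    """Builds an isolated FlagValues with one validated flag (hypothetical example)."""
    flag_values = flags.FlagValues()
    flags.DEFINE_integer(
        "batch_size", 32, "Batch size for training.", flag_values=flag_values)

    # The checker looks only at the parsed value; it does not access the file system.
    flags.register_validator(
        "batch_size",
        lambda value: value > 0,
        message="--batch_size must be a positive integer.",
        flag_values=flag_values)
    return flag_values


def parse(argv):
    """Parses argv; absl runs every registered validator as part of parsing."""
    flag_values = make_flag_values()
    flag_values(argv)
    return flag_values


if __name__ == "__main__":
    print(parse(["prog", "--batch_size=64"]).batch_size)  # prints 64
    try:
        parse(["prog", "--batch_size=-1"])
    except flags.IllegalFlagValueError as err:
        print("rejected:", err)
```

For a simple range check like this one, `DEFINE_integer` also accepts a `lower_bound` argument; a registered validator is more useful when the condition is more involved than a bound.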