# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flags for managing compute devices. Currently only contains TPU flags."""

from absl import flags
from absl import logging

from official.utils.flags._conventions import help_wrap


def require_cloud_storage(flag_names):
  """Register a validator to check directory flags.

  The validator is registered on the `tpu` flag together with every flag in
  `flag_names`. When `tpu` is None the check passes unconditionally. Otherwise
  each named flag must hold a string starting with "gs://"; every offending
  flag is logged, and validation fails if any flag is offending.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked.
  """
  # Materialize once: the names are used three times below (message text,
  # validator registration, and the check loop). A one-shot iterator such as
  # a generator would otherwise be exhausted by the first use, and
  # `["tpu"] + flag_names` only accepts a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):  # pylint: disable=missing-docstring
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path: report it as invalid rather
      # than crashing with AttributeError on `None.startswith`.
      if not isinstance(value, str) or not value.startswith("gs://"):
        logging.error("%s must be a GCS path.", key)
        valid_flags = False

    return valid_flags


def define_device(tpu=True):
  """Register device specific flags.

  When `tpu` is true, this defines the string flags `tpu`, `tpu_zone` and
  `tpu_gcp_project` and the integer flag `num_tpu_shards` (default 8). Only
  `tpu` is reported back as a key flag. When `tpu` is false, no flags are
  defined.

  Args:
    tpu: Create flags to specify TPU operation.

  Returns:
    A list of flags for core.py to mark as key flags.
  """

  key_flags = []

  if tpu:
    flags.DEFINE_string(
        name="tpu",
        default=None,
        help=help_wrap(
            "The Cloud TPU to use for training. This should be either the name "
            "used when creating the Cloud TPU, or a "
            "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the "
            "CPU of the local instance instead. (Good for debugging.)"))
    key_flags.append("tpu")

    flags.DEFINE_string(
        name="tpu_zone",
        default=None,
        help=help_wrap(
            "[Optional] GCE zone where the Cloud TPU is located. If not "
            "specified, we will attempt to automatically detect the GCE "
            "zone from metadata."))

    flags.DEFINE_string(
        name="tpu_gcp_project",
        default=None,
        help=help_wrap(
            "[Optional] Project name for the Cloud TPU-enabled project. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_integer(
        name="num_tpu_shards",
        default=8,
        help=help_wrap("Number of shards (TPU chips)."))

  return key_flags