tf_configure.bzl 7.86 KB
Newer Older
zhanggzh's avatar
zhanggzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
"""Setup TensorFlow as external dependency"""

# Name of the environment variable giving the directory that is scanned for
# TensorFlow header files.
_TF_HEADER_DIR = "TF_HEADER_DIR"

# Name of the environment variable giving the directory that contains the
# TensorFlow shared library.
_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR"

# Name of the environment variable giving the file name of the TensorFlow
# shared library inside that directory.
_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"

# Name of the environment variable whose value is appended to
# "-D_GLIBCXX_USE_CXX11_ABI=".
_TF_CXX11_ABI_FLAG = "TF_CXX11_ABI_FLAG"

# Name of the environment variable whose value is appended to "-std=".
_TF_CPLUSPLUS_VER = "TF_CPLUSPLUS_VER"

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expands a template from //build_deps/tf_dependency into the repository.

    Args:
      repository_ctx: the repository_ctx object.
      tpl: base name of the template; the file "<tpl>.tpl" in
        //build_deps/tf_dependency is used as the source.
      substitutions: dict of placeholder -> replacement, forwarded to
        repository_ctx.template.
      out: path of the generated file; `tpl` is used when this is empty or None.
    """
    template_label = Label("//build_deps/tf_dependency:%s.tpl" % tpl)
    output_path = out if out else tpl
    repository_ctx.template(output_path, template_label, substitutions)

def _fail(msg):
    """Aborts evaluation with a highlighted TensorFlow configuration error.

    Args:
      msg: string, description of what went wrong.
    """
    red = "\033[0;31m"
    no_color = "\033[0m"

    # This file configures the TensorFlow dependency, so the banner names
    # TensorFlow rather than Python to point users at the right settings.
    fail("%sTensorFlow Configuration Error:%s %s\n" % (red, no_color, msg))

def _is_windows(repository_ctx):
    """Reports whether the host OS name contains "windows" (case-insensitive).

    Args:
      repository_ctx: the repository_ctx object.

    Returns:
      True if the lower-cased repository_ctx.os.name contains "windows",
      False otherwise.
    """
    return "windows" in repository_ctx.os.name.lower()

def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Runs a command and aborts the build if it did not succeed.

    The command is considered failed when it exits with a non-zero status,
    when it writes anything to stderr, or when it produces no stdout while
    empty_stdout_fine is False.

    Args:
      repository_ctx: the repository_ctx object.
      cmdline: list of strings, the command to execute.
      error_msg: string, a summary of the error if the command fails.
      error_details: string, details about the error or steps to fix it.
      empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
        it's an error.

    Returns:
      The result of repository_ctx.execute(cmdline).
    """
    result = repository_ctx.execute(cmdline)

    # A non-zero exit status is a failure even if the command printed nothing
    # to stderr; relying on stderr alone would let such failures pass silently.
    command_failed = result.return_code != 0 or bool(result.stderr)
    missing_output = not (empty_stdout_fine or result.stdout)

    if command_failed or missing_output:
        details = result.stderr.strip()
        if not details and result.return_code != 0:
            # Give the user something to go on when stderr is empty.
            details = "Command exited with code %d" % result.return_code
        _fail("\n".join([
            error_msg.strip() if error_msg else "Repository command failed",
            details,
            error_details if error_details else "",
        ]))
    return result

def _read_dir(repository_ctx, src_dir):
    """Lists every file found under a directory.

    On Windows the listing comes from `cmd.exe /c dir <dir> /b /s /a-d`, and
    backslashes in the output are converted to forward slashes. On other hosts
    it comes from `find <dir> -follow -type f`. An empty listing is accepted.

    Args:
        repository_ctx: the repository_ctx object.
        src_dir: directory to find files from.

    Returns:
        The command's stdout: one file path per line.
    """
    if not _is_windows(repository_ctx):
        listing = _execute(
            repository_ctx,
            ["find", src_dir, "-follow", "-type", "f"],
            empty_stdout_fine = True,
        )
        return listing.stdout

    # The directory is handed to cmd.exe with backslash separators.
    windows_dir = src_dir.replace("/", "\\")
    listing = _execute(
        repository_ctx,
        ["cmd.exe", "/c", "dir", windows_dir, "/b", "/s", "/a-d"],
        empty_stdout_fine = True,
    )

    # The paths end up in genrule.outs, where forward slashes are required.
    return listing.stdout.replace("\\", "/")

def _genrule(genrule_name, command, outs):
    """Renders the BUILD-file text of a genrule.

    Args:
        genrule_name: A unique name for genrule target.
        command: The shell command text placed inside the triple-quoted `cmd`.
        outs: The already-formatted entries placed inside the `outs` list,
          as a single string.

    Returns:
        A string containing the genrule target definition.
    """
    lines = [
        "genrule(",
        '    name = "' + genrule_name + '",',
        "    outs = [",
        outs,
        "    ],",
        '    cmd = """',
        command,
        '   """,',
        ")",
    ]
    return "\n".join(lines) + "\n"

def _norm_path(path):
    """Returns the path with '/' separators and without one trailing slash.

    Args:
        path: the path string to normalize; may be empty.

    Returns:
        The path with every backslash replaced by '/', and with its final
        '/' removed if it ended with one.
    """
    path = path.replace("\\", "/")

    # endswith() is safe on an empty string, whereas indexing path[-1] is not.
    if path.endswith("/"):
        path = path[:-1]
    return path

def _symlink_genrule_for_dir(
        repository_ctx,
        src_dir,
        dest_dir,
        genrule_name,
        src_files = [],
        dest_files = [],
        tf_pip_dir_rename_pair = []):
    """Returns a genrule that copies a set of files into the repository.

    When src_dir is not None, the files are discovered by listing that
    directory; otherwise src_files and dest_files are used as given.

    Args:
        repository_ctx: the repository_ctx object.
        src_dir: source directory, or None to use src_files and dest_files.
        dest_dir: prefix prepended to every output path.
        genrule_name: genrule name.
        src_files: list of source files, used instead of src_dir.
        dest_files: list of corresponding destination files.
        tf_pip_dir_rename_pair: empty, or a two-element list [old, new]; every
          occurrence of `old` in the destination paths is replaced by `new`.
          For example, in the TF pip package the source code is under
          "tensorflow_core", and it can be renamed to "tensorflow" to match
          the header includes.

    Returns:
        The text of a genrule target that performs the copies.
    """
    rename_len = len(tf_pip_dir_rename_pair)
    if rename_len not in (0, 2):
        _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % rename_len)

    if src_dir != None:
        src_dir = _norm_path(src_dir)
        dest_dir = _norm_path(dest_dir)
        listing = _read_dir(repository_ctx, src_dir).splitlines()
        files = "\n".join(sorted(listing))
        src_files = files.splitlines()

        # Destination paths are the listing with src_dir removed and, when a
        # rename pair is given, with the old directory name replaced.
        relative = files.replace(src_dir, "")
        if rename_len:
            relative = relative.replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1])
        dest_files = relative.splitlines()

    # With a single file, dest_dir is left out of the copy destination because
    # $(@D) will already include the full path to the file.
    single_output = len(dest_files) == 1

    # Files are copied (not linked) to create a sandboxable setup.
    copy_tool = "cp -f"

    copy_cmds = []
    out_entries = []
    for idx, rel_path in enumerate(dest_files):
        if rel_path == "":
            continue
        if single_output:
            target = "$(@D)/" + rel_path
        else:
            target = "$(@D)/" + dest_dir + rel_path
        copy_cmds.append(copy_tool + ' "%s" "%s"' % (src_files[idx], target))
        out_entries.append('        "' + dest_dir + rel_path + '",')

    return _genrule(
        genrule_name,
        ";\n".join(copy_cmds),
        "\n".join(out_entries),
    )

def _tf_pip_impl(repository_ctx):
    """Implementation function of the tf_configure repository rule.

    Reads the TF_* environment variables, builds the genrules that copy the
    TensorFlow headers and shared library, and expands the "BUILD" and
    "build_defs.bzl" templates with the results.

    Args:
      repository_ctx: the repository_ctx object.
    """
    env = repository_ctx.os.environ

    # Headers: everything under TF_HEADER_DIR goes below "include", with
    # "tensorflow_core" renamed to "tensorflow" in the destination paths.
    header_genrule = _symlink_genrule_for_dir(
        repository_ctx,
        env[_TF_HEADER_DIR],
        "include",
        "tf_header_include",
        tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"],
    )

    lib_dir = env[_TF_SHARED_LIBRARY_DIR]
    lib_name = env[_TF_SHARED_LIBRARY_NAME]
    lib_path = "%s/%s" % (lib_dir, lib_name)
    abi_define = "-D_GLIBCXX_USE_CXX11_ABI=%s" % (env[_TF_CXX11_ABI_FLAG])
    std_flag = "-std=%s" % env[_TF_CPLUSPLUS_VER]

    # Shared library: a single explicit file, so no source directory is listed.
    lib_genrule = _symlink_genrule_for_dir(
        repository_ctx,
        None,
        "",
        lib_name,
        [lib_path],
        [lib_name],
    )

    build_substitutions = {
        "%{TF_HEADER_GENRULE}": header_genrule,
        "%{TF_SHARED_LIBRARY_GENRULE}": lib_genrule,
        "%{TF_SHARED_LIBRARY_NAME}": lib_name,
    }
    _tpl(repository_ctx, "BUILD", build_substitutions)

    build_defs_substitutions = {
        "%{tf_cx11_abi}": abi_define,
        "%{tf_cplusplus_ver}": std_flag,
    }
    _tpl(repository_ctx, "build_defs.bzl", build_defs_substitutions)

# Repository rule that sets up TensorFlow as an external dependency. All of its
# inputs come from the environment variables listed in `environ`.
tf_configure = repository_rule(
    implementation = _tf_pip_impl,
    # Environment variables read by the implementation function.
    environ = [
        _TF_HEADER_DIR,
        _TF_SHARED_LIBRARY_DIR,
        _TF_SHARED_LIBRARY_NAME,
        _TF_CXX11_ABI_FLAG,
        _TF_CPLUSPLUS_VER,
    ],
)