relocate.py 12.7 KB
Newer Older
1
2
3
4
5
6
# -*- coding: utf-8 -*-

"""Helper script to package wheels and relocate binaries."""

# Standard library imports
import glob
import hashlib
import io
import os
import os.path as osp
import platform
import shutil
import subprocess
import sys
import zipfile
from base64 import urlsafe_b64encode
from collections import deque

# Third party imports
if sys.platform == "linux":
    from auditwheel.lddtree import lddtree


# Sonames that relocate_elf_library never copies into the wheel: the
# glibc/gcc runtime (libc, libm, libstdc++, the dynamic loader, ...) and a few
# system graphics/GLib libraries.
ALLOWLIST = {
    "libgcc_s.so.1",
    "libstdc++.so.6",
    "libm.so.6",
    "libdl.so.2",
    "librt.so.1",
    "libc.so.6",
    "libnsl.so.1",
    "libutil.so.1",
    "libpthread.so.0",
    "libresolv.so.2",
    "libX11.so.6",
    "libXext.so.6",
    "libXrender.so.1",
    "libICE.so.6",
    "libSM.so.6",
    "libGL.so.1",
    "libgobject-2.0.so.0",
    "libgthread-2.0.so.0",
    "libglib-2.0.so.0",
    "ld-linux-x86-64.so.2",
    "ld-2.17.so",
}

# DLL names that relocate_dll_library never copies into the wheel: the MSVC
# runtime, KERNEL32 and the Universal CRT "api-ms-win-crt-*" API sets.
WINDOWS_ALLOWLIST = {
    "MSVCP140.dll",
    "KERNEL32.dll",
    "VCRUNTIME140_1.dll",
    "VCRUNTIME140.dll",
    "api-ms-win-crt-heap-l1-1-0.dll",
    "api-ms-win-crt-runtime-l1-1-0.dll",
    "api-ms-win-crt-stdio-l1-1-0.dll",
    "api-ms-win-crt-filesystem-l1-1-0.dll",
    "api-ms-win-crt-string-l1-1-0.dll",
    "api-ms-win-crt-environment-l1-1-0.dll",
    "api-ms-win-crt-math-l1-1-0.dll",
    "api-ms-win-crt-convert-l1-1-0.dll",
}


# Host interpreter / machine information, captured once at import time.
PYTHON_VERSION = sys.version_info
PLATFORM_ARCH = platform.machine()

# Absolute directory of this script, and the directory two levels above it,
# which the patch_* functions treat as the project root (they look in its
# "dist" folder for wheels).
HERE = osp.split(osp.abspath(__file__))[0]
PACKAGE_ROOT = osp.split(osp.split(HERE)[0])[0]


def read_chunks(file, size=io.DEFAULT_BUFFER_SIZE):
    """Generate successive ``file.read(size)`` results.

    Stops at the first empty (falsy) read, i.e. at end of file; the empty
    read itself is not yielded.
    """
    piece = file.read(size)
    while piece:
        yield piece
        piece = file.read(size)


def rehash(path, blocksize=1 << 20):
    """Return ``(digest, length)`` for the file at ``path``, in wheel RECORD form.

    :param path: file to hash.
    :param blocksize: number of bytes read per chunk.
    :returns: ``digest`` is ``"sha256="`` followed by the URL-safe base64 of the
        SHA-256 digest with ``=`` padding stripped; ``length`` is the file size
        in bytes as a decimal string.
    """
    h = hashlib.sha256()
    length = 0
    with open(path, "rb") as f:
        for block in read_chunks(f, size=blocksize):
            length += len(block)
            h.update(block)
    # base64 output is pure ASCII, so decoding cannot fail.
    digest = "sha256=" + urlsafe_b64encode(h.digest()).decode("ascii").rstrip("=")
    return (digest, str(length))


def unzip_file(file, dest):
    """Decompress zip archive ``file`` into directory ``dest``.

    :param file: path or binary file-like object accepted by ``zipfile.ZipFile``
        (here, a ``.whl`` file).
    :param dest: directory that receives the extracted members; it is created
        by ``extractall`` if needed.
    """
    with zipfile.ZipFile(file, "r") as zip_ref:
        zip_ref.extractall(dest)


def is_program_installed(basename):
    """
    Return the path of ``basename`` if it is a file in a PATH directory.

    The returned path is ``<PATH entry>/<basename>`` for the first PATH entry
    that contains it; it is absolute only if that PATH entry is. On macOS, a
    ``.app`` is considered installed if it exists, and ``basename`` is returned
    as-is. Returns None when nothing is found (including when PATH is unset).
    """
    if sys.platform == "darwin" and basename.endswith(".app") and osp.exists(basename):
        return basename

    # An unset PATH means nothing can be found on it; indexing os.environ
    # directly would raise KeyError instead.
    for path in os.environ.get("PATH", "").split(os.pathsep):
        abspath = osp.join(path, basename)
        if osp.isfile(abspath):
            return abspath
    return None


def find_program(basename):
    """
    Find a program (or, on Windows, a DLL) on PATH and return its path.

    On Windows (``os.name == "nt"``), when ``basename`` does not already end in
    ``.exe``, ``.bat``, ``.cmd`` or ``.dll`` (compared case-insensitively), each
    of those suffixes is tried in that order before the bare name.
    Returns None if nothing is found.
    """
    names = [basename]
    if os.name == "nt":
        # Windows platforms
        extensions = (".exe", ".bat", ".cmd", ".dll")
        # Windows file names are case-insensitive, so "FOO.DLL" already has an
        # extension and should not be probed as "FOO.DLL.exe" etc.
        if not basename.lower().endswith(extensions):
            names = [basename + ext for ext in extensions] + [basename]
    for name in names:
        path = is_program_installed(name)
        if path:
            return path
    return None


def patch_new_path(library_path, new_dir):
    """Return the destination path inside ``new_dir`` for ``library_path``.

    The first 8 hex digits of the SHA-256 of the full ``library_path`` are
    inserted after the first dot-separated part of the file name, e.g.
    ``/usr/lib/libpng16.so.16`` -> ``<new_dir>/libpng16.<hash>.so.16``. Because
    the hash covers the full original path, same-named libraries from different
    directories map to distinct names.
    """
    library = osp.basename(library_path)
    name, sep, rest = library.partition(".")
    hash_id = hashlib.sha256(library_path.encode("utf-8")).hexdigest()[:8]
    # Only re-append a suffix when there is one; joining an empty suffix would
    # leave a trailing dot ("name.hash.").
    new_name = name + "." + hash_id + (sep + rest if rest else "")
    return osp.join(new_dir, new_name)


def find_dll_dependencies(dumpbin, binary):
    """Return the DLL names that ``binary`` depends on, as reported by dumpbin.

    Runs ``<dumpbin> /dependents <binary>`` and returns the non-blank, stripped
    lines between the ``...dependencies:`` header and the following ``Summary``
    (or end of output). Returns an empty list when no dependency section is
    present in the output.
    """
    result = subprocess.run([dumpbin, "/dependents", binary], stdout=subprocess.PIPE)
    out = result.stdout.decode("utf-8", errors="replace").strip()

    marker = "dependencies:"
    start_index = out.find(marker)
    if start_index == -1:
        # No dependency section: slicing from find()'s -1 would start in the
        # middle of the header and return unrelated text as "DLL names".
        return []
    start_index += len(marker)

    end_index = out.find("Summary", start_index)
    if end_index == -1:
        end_index = len(out)

    # splitlines() handles both "\r\n" and "\n" independently of os.linesep;
    # skipping blank lines avoids returning [""] for an empty section.
    section = out[start_index:end_index]
    return [line.strip() for line in section.splitlines() if line.strip()]


def relocate_elf_library(patchelf, output_dir, output_library, binary):
    """
    Relocate an ELF shared library to be packaged on a wheel.

    Given a shared library, find the transitive closure of its dependencies,
    rename and copy them into the wheel while updating their respective rpaths.

    :param patchelf: path to the ``patchelf`` executable.
    :param output_dir: root of the unpacked wheel; bundled dependencies are
        copied into ``<output_dir>/torchvision.libs``.
    :param output_library: directory containing ``binary``.
    :param binary: file name (relative to ``output_library``) of the library
        to relocate.
    """

    print("Relocating {0}".format(binary))
    binary_path = osp.join(output_library, binary)

    ld_tree = lddtree(binary_path)
    tree_libs = ld_tree["libs"]

    # Breadth-first walk over (library, parent) pairs, starting from the
    # libraries lddtree reports as needed by the binary itself.
    binary_queue = deque((n, binary) for n in ld_tree["needed"])
    binary_paths = {binary: binary_path}
    binary_dependencies = {}

    while binary_queue:
        library, parent = binary_queue.popleft()
        library_info = tree_libs[library]
        print(library)

        if library_info["path"] is None:
            print("Omitting {0}".format(library))
            continue

        if library in ALLOWLIST:
            # Omit glibc/gcc/system libraries
            print("Omitting {0}".format(library))
            continue

        binary_dependencies.setdefault(parent, []).append(library)

        if library in binary_paths:
            continue

        binary_paths[library] = library_info["path"]
        binary_queue.extend((n, library) for n in library_info["needed"])

    print("Copying dependencies to wheel directory")
    new_libraries_path = osp.join(output_dir, "torchvision.libs")
    # patch_linux calls this once per extension (image.so, video_reader.so) on
    # the same output_dir, so the directory may already exist.
    os.makedirs(new_libraries_path, exist_ok=True)

    new_names = {binary: binary_path}

    for library, library_path in binary_paths.items():
        if library == binary:
            continue
        new_library_path = patch_new_path(library_path, new_libraries_path)
        print("{0} -> {1}".format(library, new_library_path))
        shutil.copyfile(library_path, new_library_path)
        new_names[library] = new_library_path

    print("Updating dependency names by new files")
    for library in binary_paths:
        # Libraries with no bundled dependencies are left untouched.
        if library == binary or library not in binary_dependencies:
            continue
        new_library_name = new_names[library]
        for dep in binary_dependencies[library]:
            new_dep = osp.basename(new_names[dep])
            print("{0}: {1} -> {2}".format(library, dep, new_dep))
            subprocess.check_output(
                [patchelf, "--replace-needed", dep, new_dep, new_library_name], cwd=new_libraries_path
            )

        print("Updating library rpath")
        # Bundled libraries find each other in their own directory.
        subprocess.check_output([patchelf, "--set-rpath", "$ORIGIN", new_library_name], cwd=new_libraries_path)

        subprocess.check_output([patchelf, "--print-rpath", new_library_name], cwd=new_libraries_path)

    print("Update library dependencies")
    # If every dependency was omitted, the binary never became a key in
    # binary_dependencies; indexing it directly would raise KeyError.
    for dep in binary_dependencies.get(binary, []):
        new_dep = osp.basename(new_names[dep])
        print("{0}: {1} -> {2}".format(binary, dep, new_dep))
        subprocess.check_output([patchelf, "--replace-needed", dep, new_dep, binary], cwd=output_library)

    print("Update library rpath")
    subprocess.check_output(
        [patchelf, "--set-rpath", "$ORIGIN:$ORIGIN/../torchvision.libs", binary_path], cwd=output_library
    )


def relocate_dll_library(dumpbin, output_dir, output_library, binary):
    """
    Relocate a DLL/PE shared library to be packaged on a wheel.

    Given a shared library, find the transitive closure of its dependencies,
    rename and copy them into the wheel.

    :param dumpbin: path to the ``dumpbin`` executable.
    :param output_dir: root of the unpacked wheel; dependencies are copied into
        ``<output_dir>/torchvision``.
    :param output_library: directory containing ``binary``.
    :param binary: file name (relative to ``output_library``) of the library
        to relocate.
    """
    print("Relocating {0}".format(binary))
    binary_path = osp.join(output_library, binary)

    # Windows file names are case-insensitive, and the casing in an import
    # table need not match the allowlist, so compare lower-cased names.
    allowlist = {name.lower() for name in WINDOWS_ALLOWLIST}

    binary_queue = deque(find_dll_dependencies(dumpbin, binary_path))
    binary_paths = {binary: binary_path}

    while binary_queue:
        library = binary_queue.popleft()
        library_key = library.lower()

        if library_key in allowlist or library_key.startswith("api-ms-win"):
            print("Omitting {0}".format(library))
            continue

        library_path = find_program(library)
        if library_path is None:
            print("{0} not found".format(library))
            continue

        # System DLLs are never bundled. PATH may spell the folder "System32",
        # so the comparison must be case-insensitive.
        if osp.basename(osp.dirname(library_path)).lower() == "system32":
            continue

        print("{0}: {1}".format(library, library_path))

        if library in binary_paths:
            continue

        binary_paths[library] = library_path
        binary_queue.extend(find_dll_dependencies(dumpbin, library_path))

    print("Copying dependencies to wheel directory")
    package_dir = osp.join(output_dir, "torchvision")
    for library, library_path in binary_paths.items():
        if library == binary:
            continue
        new_library_path = osp.join(package_dir, library)
        print("{0} -> {1}".format(library, new_library_path))
        shutil.copyfile(library_path, new_library_path)


def compress_wheel(output_dir, wheel, wheel_dir, wheel_name):
    """Create RECORD file and compress wheel distribution.

    Rewrites ``<output_dir>/*.dist-info/RECORD`` with a ``path,hash,size`` row
    for every file under ``output_dir``. It then zips ``output_dir`` into
    ``<wheel_dir>/<wheel_name>.zip``, replaces ``wheel`` with that archive and
    deletes ``output_dir``.

    :raises FileNotFoundError: if ``output_dir`` has no ``*.dist-info`` directory.
    """
    print("Update RECORD file in wheel")
    dist_infos = glob.glob(osp.join(output_dir, "*.dist-info"))
    if not dist_infos:
        raise FileNotFoundError("No .dist-info directory found in {0}".format(output_dir))
    dist_info = dist_infos[0]
    record_file = osp.join(dist_info, "RECORD")

    with open(record_file, "w") as f:
        for root, _, files in os.walk(output_dir):
            for this_file in files:
                full_file = osp.join(root, this_file)
                # RECORD paths must use "/" separators; osp.relpath yields
                # backslashes on Windows (patch_win also calls this function).
                rel_file = osp.relpath(full_file, output_dir).replace(os.sep, "/")
                if full_file == record_file:
                    # RECORD lists itself without a hash or size.
                    f.write("{0},,\n".format(rel_file))
                else:
                    digest, size = rehash(full_file)
                    f.write("{0},{1},{2}\n".format(rel_file, digest, size))

    print("Compressing wheel")
    base_wheel_name = osp.join(wheel_dir, wheel_name)
    shutil.make_archive(base_wheel_name, "zip", output_dir)
    os.remove(wheel)
    shutil.move("{0}.zip".format(base_wheel_name), wheel)
    shutil.rmtree(output_dir)


def _process_wheels(binaries, relocate, tool, description):
    """Unpack each wheel in ``<PACKAGE_ROOT>/dist``, relocate ``binaries`` and repack it.

    For every name in ``binaries`` present in the unpacked ``torchvision``
    directory, ``relocate(tool, output_dir, output_library, binary)`` is called.
    ``description`` is only used in the progress message.
    """
    print("Finding wheels...")
    wheels = glob.glob(osp.join(PACKAGE_ROOT, "dist", "*.whl"))
    output_dir = osp.join(PACKAGE_ROOT, "dist", ".wheel-process")

    for wheel in wheels:
        # Each wheel is processed in a fresh scratch directory.
        if osp.exists(output_dir):
            shutil.rmtree(output_dir)

        os.makedirs(output_dir)

        print("Unzipping wheel...")
        wheel_file = osp.basename(wheel)
        wheel_dir = osp.dirname(wheel)
        print("{0}".format(wheel_file))
        wheel_name, _ = osp.splitext(wheel_file)
        unzip_file(wheel, output_dir)

        print("Finding {0} dependencies...".format(description))
        output_library = osp.join(output_dir, "torchvision")
        for binary in binaries:
            if osp.exists(osp.join(output_library, binary)):
                relocate(tool, output_dir, output_library, binary)

        compress_wheel(output_dir, wheel, wheel_dir, wheel_name)


def patch_linux():
    """Bundle the ELF dependencies of torchvision's extensions into each wheel.

    :raises FileNotFoundError: if ``patchelf`` is not on PATH.
    """
    # Get patchelf location
    patchelf = find_program("patchelf")
    if patchelf is None:
        raise FileNotFoundError(
            "Patchelf was not found in the system, please" " make sure that is available on the PATH."
        )

    _process_wheels(["image.so", "video_reader.so"], relocate_elf_library, patchelf, "ELF")


def patch_win():
    """Bundle the DLL dependencies of torchvision's extensions into each wheel.

    :raises FileNotFoundError: if ``dumpbin`` is not on PATH.
    """
    # Get dumpbin location
    dumpbin = find_program("dumpbin")
    if dumpbin is None:
        raise FileNotFoundError(
            "Dumpbin was not found in the system, please" " make sure that is available on the PATH."
        )

    _process_wheels(["image.pyd", "video_reader.pyd"], relocate_dll_library, dumpbin, "DLL/PE")


if __name__ == "__main__":
    # Linux wheels are patched with patchelf and Windows wheels with dumpbin;
    # on any other platform the script does nothing.
    if sys.platform == "linux":
        patch_linux()
    elif sys.platform == "win32":
        patch_win()