# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
from typing import Any, Dict, NamedTuple, Protocol, Tuple

import torch

# boto3 is an optional dependency; it is only needed when a dataset actually
# lives in S3, so a missing installation is tolerated at import time.
try:
    import boto3
    import botocore.exceptions as exceptions
except ModuleNotFoundError:
    pass

S3_PREFIX = "s3://"


class S3Config(NamedTuple):
    """Config when the data (.bin) file and the index (.idx) file are in S3

    TODO: These parameters are few and can be consolidated with parameters specific to bin reader
    classes - @jkamalu

    Attributes:

        path_to_idx_cache (str): The local directory where we will store the index (.idx) file

        bin_chunk_nbytes (int): If the number of bytes is too small, then we send a request to S3
            at each call of the `read` method in _S3BinReader, which is slow, because each request
            has a fixed cost independent of the size of the byte range requested. If the number of
            bytes is too large, then we only rarely have to send requests to S3, but it takes a
            lot of time to complete the request when we do, which can block training. We've found
            that 256 * 1024 * 1024 (i.e., 256 MiB) has worked well (though we have not put that
            much effort into tuning it), so we default to it.
    """

    path_to_idx_cache: str

    bin_chunk_nbytes: int = 256 * 1024 * 1024
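
# Example (illustrative): a config that caches downloaded .idx files under a
# local scratch directory and keeps the default 256 MiB chunk size. The cache
# path is hypothetical.
#
#   config = S3Config(path_to_idx_cache="/tmp/idx_cache")
#   config.bin_chunk_nbytes  # 268435456 (256 * 1024 * 1024)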


class S3Client(Protocol):
    """The protocol which all s3 clients should abide by"""

    def download_file(self, Bucket: str, Key: str, Filename: str) -> None: ...

    def upload_file(self, Filename: str, Bucket: str, Key: str) -> None: ...

    def head_object(self, Bucket: str, Key: str) -> Dict[str, Any]: ...

    def get_object(self, Bucket: str, Key: str, Range: str) -> Dict[str, Any]: ...

    def close(self) -> None: ...
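
# Note: the client returned by `boto3.client("s3")` provides all of the
# methods above, so it can be used wherever an `S3Client` is expected:
#
#   client: S3Client = boto3.client("s3")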


def is_s3_path(path: str) -> bool:
    """Ascertain whether a path is in S3

    Args:
        path (str): The path

    Returns:
        bool: True if the path is in S3, False otherwise
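
    Example::

        >>> is_s3_path("s3://bucket/key")
        True
        >>> is_s3_path("/tmp/local/file.idx")
        False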
    """
    return path.startswith(S3_PREFIX)


def parse_s3_path(path: str) -> Tuple[str, str]:
    """Parses the given S3 path returning correspsonding bucket and key.

    Args:
        path (str): The S3 path

    Returns:
        Tuple[str, str]: A (bucket, key) tuple
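
    Example::

        >>> parse_s3_path("s3://my-bucket/path/to/file.idx")
        ('my-bucket', 'path/to/file.idx')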
    """
    assert is_s3_path(path)
    # Strip only the leading scheme; `str.replace` would also mangle any
    # "s3://" that happens to appear inside the key itself.
    parts = path[len(S3_PREFIX) :].split("/")
    bucket = parts[0]
    if len(parts) > 1:
        key = "/".join(parts[1:])
        assert S3_PREFIX + bucket + "/" + key == path
    else:
        key = ""
    return bucket, key


def object_exists(client: S3Client, path: str) -> bool:
    """Ascertain whether the object at the given S3 path exists in S3

    Args:
        client (S3Client): The S3 client

        path (str): The S3 path

    Raises:
        botocore.exceptions.ClientError: If the error from S3 is anything other than a 404

    Returns:
        bool: True if the object exists in S3, False otherwise
    """
    parsed_s3_path = parse_s3_path(path)
    try:
        # boto3 expects capitalized keyword arguments (see the S3Client protocol).
        client.head_object(Bucket=parsed_s3_path[0], Key=parsed_s3_path[1])
    except exceptions.ClientError as e:
        # A 404 means the object does not exist; anything else is a real error.
        if e.response["Error"]["Code"] != "404":
            raise e
        return False
    return True

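# Example (illustrative): check for an index file before operating on it.
# The bucket and key are hypothetical.
#
#   client = boto3.client("s3")
#   if object_exists(client, "s3://my-bucket/dataset/corpus.idx"):
#       ...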

def _download_file(client: S3Client, s3_path: str, local_path: str) -> None:
    """Download the object at the given S3 path to the given local file system path

    Args:
        client (S3Client): The S3 client

        s3_path (str): The S3 source path

        local_path (str): The local destination path
    """
    dirname = os.path.dirname(local_path)
    os.makedirs(dirname, exist_ok=True)
    parsed_s3_path = parse_s3_path(s3_path)
    client.download_file(parsed_s3_path[0], parsed_s3_path[1], local_path)


def maybe_download_file(s3_path: str, local_path: str) -> None:
    """Download the object at the given S3 path to the given local file system path

    In a distributed setting, the download proceeds in stages, so that the
    minimum number of processes necessary for all ranks to have access to
    the downloaded object actually perform the download.

    Args:
        s3_path (str): The S3 source path

        local_path (str): The local destination path
    """

    if torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        # Assumes one rank per GPU, so the number of visible CUDA devices
        # equals the number of ranks per node.
        local_rank = rank % torch.cuda.device_count()
    else:
        rank = 0
        local_rank = 0

    s3_client = boto3.client("s3")

    # First, have only the global rank 0 download the file; this suffices
    # when `local_path` is on a file system shared across all ranks.
    if (not os.path.exists(local_path)) and (rank == 0):
        _download_file(s3_client, s3_path, local_path)

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # If the `local_path` is in a file system that is not
    # shared across all the ranks, then we assume it's in the
    # host file system and each host needs to download the file.
    if (not os.path.exists(local_path)) and (local_rank == 0):
        _download_file(s3_client, s3_path, local_path)

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # If the `local_path` still does not exist, then we assume
    # each rank is saving to a separate location.
    if not os.path.exists(local_path):
        _download_file(s3_client, s3_path, local_path)

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    assert os.path.exists(local_path)
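
# Example (illustrative): fetch an index file before constructing a dataset.
# The S3 path and local cache path are hypothetical.
#
#   maybe_download_file(
#       "s3://my-bucket/dataset/corpus.idx",
#       "/local/cache/corpus.idx",
#   )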