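"""Multi-bucket S3 readers and writers.

Each read or write is routed to a per-bucket S3 client chosen from a list of
S3Config entries; clients are created lazily and cached per bucket. Relative
paths are resolved against a default bucket and prefix.
"""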

from ..utils.exceptions import InvalidConfig, InvalidParams
from .base import DataReader, DataWriter
from ..io.s3 import S3Reader, S3Writer
from ..utils.schemas import S3Config
from ..utils.path_utils import parse_s3_range_params, parse_s3path, remove_non_official_s3_args


class MultiS3Mixin:
    def __init__(self, default_prefix: str, s3_configs: list[S3Config]):
        """Initialized with multiple s3 configs.

        Args:
            default_prefix (str): the default prefix of the relative path. for example, {some_bucket}/{some_prefix} or {some_bucket}
            s3_configs (list[S3Config]): list of s3 configs, the bucket_name must be unique in the list.

        Raises:
            InvalidConfig: default bucket config not in s3_configs.
            InvalidConfig: bucket name not unique in s3_configs.
            InvalidConfig: default bucket must be provided.
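
        Example:
            A minimal construction sketch. Keyword construction of S3Config is
            assumed here; the field names match the attributes this module
            reads (bucket_name, access_key, secret_key, endpoint_url,
            addressing_style), and the credentials are placeholders.

                configs = [
                    S3Config(
                        bucket_name='bucket-a',
                        access_key='ak',
                        secret_key='sk',
                        endpoint_url='https://s3.example.com',
                        addressing_style='path',
                    ),
                ]
                reader = MultiBucketS3DataReader('bucket-a/some/prefix', configs)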
        """
        if len(default_prefix) == 0:
            raise InvalidConfig('default_prefix must be provided')

        arr = default_prefix.strip('/').split('/')
        self.default_bucket = arr[0]
        self.default_prefix = '/'.join(arr[1:])

        found_default_bucket_config = any(
            conf.bucket_name == self.default_bucket for conf in s3_configs
        )

        if not found_default_bucket_config:
            raise InvalidConfig(
                f'default_bucket: {self.default_bucket} config must be provided in s3_configs: {s3_configs}'
            )

        uniq_bucket = {conf.bucket_name for conf in s3_configs}
        if len(uniq_bucket) != len(s3_configs):
            raise InvalidConfig(
                f'the bucket_name in s3_configs: {s3_configs} must be unique'
            )

        self.s3_configs = s3_configs
        self._s3_clients_h: dict = {}


class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
    def read(self, path: str) -> bytes:
        """Read the path from s3, select diffect bucket client for each request
        based on the bucket, also support range read.

        Args:
            path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit.
            for example: s3://bucket_name/path?0,100.

        Returns:
            bytes: the content of s3 file.
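
        Example:
            A sketch of a whole-file read and a range read of the first 100
            bytes; bucket and key names are placeholders.

                whole = reader.read('s3://bucket-a/data/file.bin')
                head = reader.read('s3://bucket-a/data/file.bin?0,100')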
        """
        may_range_params = parse_s3_range_params(path)
        if may_range_params is None or len(may_range_params) != 2:
            byte_start, byte_len = 0, -1
        else:
            byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
        path = remove_non_official_s3_args(path)
        return self.read_at(path, byte_start, byte_len)

    def __get_s3_client(self, bucket_name: str):
        """Return the cached S3Reader for bucket_name, creating it on first use."""
        if bucket_name not in {conf.bucket_name for conf in self.s3_configs}:
            raise InvalidParams(
                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
            )
        if bucket_name not in self._s3_clients_h:
            conf = next(
                filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
            )
            self._s3_clients_h[bucket_name] = S3Reader(
                bucket_name,
                conf.access_key,
                conf.secret_key,
                conf.endpoint_url,
                conf.addressing_style,
            )
        return self._s3_clients_h[bucket_name]

    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read the file with offset and limit, select diffect bucket client
        for each request based on the bucket.

        Args:
            path (str): the file path.
            offset (int, optional): the number of bytes skipped. Defaults to 0.
            limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.

        Returns:
            bytes: the file content.
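
        Example:
            A sketch that skips the first 16 bytes and reads the next 32; the
            relative key is resolved under the default bucket and prefix.

                chunk = reader.read_at('relative/key.bin', offset=16, limit=32)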
        """
        if path.startswith('s3://'):
            bucket_name, path = parse_s3path(path)
            s3_reader = self.__get_s3_client(bucket_name)
        else:
            s3_reader = self.__get_s3_client(self.default_bucket)
            if self.default_prefix:
                path = self.default_prefix + '/' + path
        return s3_reader.read_at(path, offset, limit)


class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
    def __get_s3_client(self, bucket_name: str):
        """Return the cached S3Writer for bucket_name, creating it on first use."""
        if bucket_name not in {conf.bucket_name for conf in self.s3_configs}:
            raise InvalidParams(
                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
            )
        if bucket_name not in self._s3_clients_h:
            conf = next(
                filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
            )
            self._s3_clients_h[bucket_name] = S3Writer(
                bucket_name,
                conf.access_key,
                conf.secret_key,
                conf.endpoint_url,
                conf.addressing_style,
            )
        return self._s3_clients_h[bucket_name]

    def write(self, path: str, data: bytes) -> None:
        """Write file with data, also select diffect bucket client for each
        request based on the bucket.

        Args:
            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
            data (bytes): the data want to write.
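
        Example:
            A sketch of an absolute write and a relative write; the relative
            path lands under the default bucket and prefix.

                writer.write('s3://bucket-a/out/file.bin', b'hello')
                writer.write('out/file.bin', b'hello')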
        """
        if path.startswith('s3://'):
            bucket_name, path = parse_s3path(path)
            s3_writer = self.__get_s3_client(bucket_name)
        else:
            s3_writer = self.__get_s3_client(self.default_bucket)
            if self.default_prefix:
                path = self.default_prefix + '/' + path
        return s3_writer.write(path, data)