Commit 6f8a9610 authored by myhloli's avatar myhloli
Browse files

feat: implement S3 data reader and writer with multi-bucket support

parent bd927919
from .base import DataReader, DataWriter
from .dummy import DummyDataWriter
from .filebase import FileBasedDataReader, FileBasedDataWriter
from .multi_bucket_s3 import MultiBucketS3DataReader, MultiBucketS3DataWriter
from .s3 import S3DataReader, S3DataWriter
__all__ = [
"DataReader",
"DataWriter",
"FileBasedDataReader",
"FileBasedDataWriter",
"S3DataReader",
"S3DataWriter",
"MultiBucketS3DataReader",
"MultiBucketS3DataWriter",
"DummyDataWriter",
]
from abc import ABC, abstractmethod
class DataReader(ABC):
def read(self, path: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return self.read_at(path)
@abstractmethod
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read the file at offset and limit.
Args:
path (str): the file path
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of the file
"""
pass
class DataWriter(ABC):
@abstractmethod
def write(self, path: str, data: bytes) -> None:
"""Write the data to the file.
Args:
path (str): the target file where to write
data (bytes): the data want to write
"""
pass
def write_string(self, path: str, data: str) -> None:
"""Write the data to file, the data will be encoded to bytes.
Args:
path (str): the target file where to write
data (str): the data want to write
"""
def safe_encode(data: str, method: str):
try:
bit_data = data.encode(encoding=method, errors='replace')
return bit_data, True
except: # noqa
return None, False
for method in ['utf-8', 'ascii']:
bit_data, flag = safe_encode(data, method)
if flag:
self.write(path, bit_data)
break
from .base import DataWriter
class DummyDataWriter(DataWriter):
def write(self, path: str, data: bytes) -> None:
"""Dummy write method that does nothing."""
pass
def write_string(self, path: str, data: str) -> None:
"""Dummy write_string method that does nothing."""
pass
import os
from .base import DataReader, DataWriter
class FileBasedDataReader(DataReader):
def __init__(self, parent_dir: str = ''):
"""Initialized with parent_dir.
Args:
parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
"""
self._parent_dir = parent_dir
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
fn_path = path
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
with open(fn_path, 'rb') as f:
f.seek(offset)
if limit == -1:
return f.read()
else:
return f.read(limit)
class FileBasedDataWriter(DataWriter):
def __init__(self, parent_dir: str = '') -> None:
"""Initialized with parent_dir.
Args:
parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
"""
self._parent_dir = parent_dir
def write(self, path: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
fn_path = path
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
if not os.path.exists(os.path.dirname(fn_path)) and os.path.dirname(fn_path) != "":
os.makedirs(os.path.dirname(fn_path), exist_ok=True)
with open(fn_path, 'wb') as f:
f.write(data)
from ..utils.exceptions import InvalidConfig, InvalidParams
from .base import DataReader, DataWriter
from ..io.s3 import S3Reader, S3Writer
from ..utils.schemas import S3Config
from ..utils.path_utils import parse_s3_range_params, parse_s3path, remove_non_official_s3_args
class MultiS3Mixin:
def __init__(self, default_prefix: str, s3_configs: list[S3Config]):
"""Initialized with multiple s3 configs.
Args:
default_prefix (str): the default prefix of the relative path. for example, {some_bucket}/{some_prefix} or {some_bucket}
s3_configs (list[S3Config]): list of s3 configs, the bucket_name must be unique in the list.
Raises:
InvalidConfig: default bucket config not in s3_configs.
InvalidConfig: bucket name not unique in s3_configs.
InvalidConfig: default bucket must be provided.
"""
if len(default_prefix) == 0:
raise InvalidConfig('default_prefix must be provided')
arr = default_prefix.strip('/').split('/')
self.default_bucket = arr[0]
self.default_prefix = '/'.join(arr[1:])
found_default_bucket_config = False
for conf in s3_configs:
if conf.bucket_name == self.default_bucket:
found_default_bucket_config = True
break
if not found_default_bucket_config:
raise InvalidConfig(
f'default_bucket: {self.default_bucket} config must be provided in s3_configs: {s3_configs}'
)
uniq_bucket = set([conf.bucket_name for conf in s3_configs])
if len(uniq_bucket) != len(s3_configs):
raise InvalidConfig(
f'the bucket_name in s3_configs: {s3_configs} must be unique'
)
self.s3_configs = s3_configs
self._s3_clients_h: dict = {}
class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
def read(self, path: str) -> bytes:
"""Read the path from s3, select diffect bucket client for each request
based on the bucket, also support range read.
Args:
path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit.
for example: s3://bucket_name/path?0,100.
Returns:
bytes: the content of s3 file.
"""
may_range_params = parse_s3_range_params(path)
if may_range_params is None or 2 != len(may_range_params):
byte_start, byte_len = 0, -1
else:
byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
path = remove_non_official_s3_args(path)
return self.read_at(path, byte_start, byte_len)
def __get_s3_client(self, bucket_name: str):
if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
raise InvalidParams(
f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
)
if bucket_name not in self._s3_clients_h:
conf = next(
filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
)
self._s3_clients_h[bucket_name] = S3Reader(
bucket_name,
conf.access_key,
conf.secret_key,
conf.endpoint_url,
conf.addressing_style,
)
return self._s3_clients_h[bucket_name]
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read the file with offset and limit, select diffect bucket client
for each request based on the bucket.
Args:
path (str): the file path.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.
Returns:
bytes: the file content.
"""
if path.startswith('s3://'):
bucket_name, path = parse_s3path(path)
s3_reader = self.__get_s3_client(bucket_name)
else:
s3_reader = self.__get_s3_client(self.default_bucket)
if self.default_prefix:
path = self.default_prefix + '/' + path
return s3_reader.read_at(path, offset, limit)
class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
def __get_s3_client(self, bucket_name: str):
if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
raise InvalidParams(
f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
)
if bucket_name not in self._s3_clients_h:
conf = next(
filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
)
self._s3_clients_h[bucket_name] = S3Writer(
bucket_name,
conf.access_key,
conf.secret_key,
conf.endpoint_url,
conf.addressing_style,
)
return self._s3_clients_h[bucket_name]
def write(self, path: str, data: bytes) -> None:
"""Write file with data, also select diffect bucket client for each
request based on the bucket.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write.
"""
if path.startswith('s3://'):
bucket_name, path = parse_s3path(path)
s3_writer = self.__get_s3_client(bucket_name)
else:
s3_writer = self.__get_s3_client(self.default_bucket)
if self.default_prefix:
path = self.default_prefix + '/' + path
return s3_writer.write(path, data)
from .multi_bucket_s3 import MultiBucketS3DataReader, MultiBucketS3DataWriter
from ..utils.schemas import S3Config
class S3DataReader(MultiBucketS3DataReader):
def __init__(
self,
default_prefix_without_bucket: str,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
default_prefix_without_bucket: prefix that not contains bucket
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
super().__init__(
f'{bucket}/{default_prefix_without_bucket}',
[
S3Config(
bucket_name=bucket,
access_key=ak,
secret_key=sk,
endpoint_url=endpoint_url,
addressing_style=addressing_style,
)
],
)
class S3DataWriter(MultiBucketS3DataWriter):
def __init__(
self,
default_prefix_without_bucket: str,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 writer client.
Args:
default_prefix_without_bucket: prefix that not contains bucket
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
super().__init__(
f'{bucket}/{default_prefix_without_bucket}',
[
S3Config(
bucket_name=bucket,
access_key=ak,
secret_key=sk,
endpoint_url=endpoint_url,
addressing_style=addressing_style,
)
],
)
from .base import IOReader, IOWriter
from .http import HttpReader, HttpWriter
from .s3 import S3Reader, S3Writer
__all__ = ['IOReader', 'IOWriter', 'HttpReader', 'HttpWriter', 'S3Reader', 'S3Writer']
\ No newline at end of file
from abc import ABC, abstractmethod
class IOReader(ABC):
@abstractmethod
def read(self, path: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
pass
@abstractmethod
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
pass
class IOWriter(ABC):
@abstractmethod
def write(self, path: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
pass
import io
import requests
from .base import IOReader, IOWriter
class HttpReader(IOReader):
def read(self, url: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return requests.get(url).content
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Not Implemented."""
raise NotImplementedError
class HttpWriter(IOWriter):
def write(self, url: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
files = {'file': io.BytesIO(data)}
response = requests.post(url, files=files)
assert 300 > response.status_code and response.status_code > 199
import boto3
from botocore.config import Config
from ..io.base import IOReader, IOWriter
class S3Reader(IOReader):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
self._bucket = bucket
self._ak = ak
self._sk = sk
self._s3_client = boto3.client(
service_name='s3',
aws_access_key_id=ak,
aws_secret_access_key=sk,
endpoint_url=endpoint_url,
config=Config(
s3={'addressing_style': addressing_style},
retries={'max_attempts': 5, 'mode': 'standard'},
),
)
def read(self, key: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return self.read_at(key)
def read_at(self, key: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
if limit > -1:
range_header = f'bytes={offset}-{offset+limit-1}'
res = self._s3_client.get_object(
Bucket=self._bucket, Key=key, Range=range_header
)
else:
res = self._s3_client.get_object(
Bucket=self._bucket, Key=key, Range=f'bytes={offset}-'
)
return res['Body'].read()
class S3Writer(IOWriter):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
self._bucket = bucket
self._ak = ak
self._sk = sk
self._s3_client = boto3.client(
service_name='s3',
aws_access_key_id=ak,
aws_secret_access_key=sk,
endpoint_url=endpoint_url,
config=Config(
s3={'addressing_style': addressing_style},
retries={'max_attempts': 5, 'mode': 'standard'},
),
)
def write(self, key: str, data: bytes):
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
self._s3_client.put_object(Bucket=self._bucket, Key=key, Body=data)
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
class FileNotExisted(Exception):
def __init__(self, path):
self.path = path
def __str__(self):
return f'File {self.path} does not exist.'
class InvalidConfig(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Invalid config: {self.msg}'
class InvalidParams(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Invalid params: {self.msg}'
class EmptyData(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Empty data: {self.msg}'
class CUDA_NOT_AVAILABLE(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'CUDA not available: {self.msg}'
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
def remove_non_official_s3_args(s3path):
"""
example: s3://abc/xxxx.json?bytes=0,81350 ==> s3://abc/xxxx.json
"""
arr = s3path.split("?")
return arr[0]
def parse_s3path(s3path: str):
# from s3pathlib import S3Path
# p = S3Path(remove_non_official_s3_args(s3path))
# return p.bucket, p.key
s3path = remove_non_official_s3_args(s3path).strip()
if s3path.startswith(('s3://', 's3a://')):
prefix, path = s3path.split('://', 1)
bucket_name, key = path.split('/', 1)
return bucket_name, key
elif s3path.startswith('/'):
raise ValueError("The provided path starts with '/'. This does not conform to a valid S3 path format.")
else:
raise ValueError("Invalid S3 path format. Expected 's3://bucket-name/key' or 's3a://bucket-name/key'.")
def parse_s3_range_params(s3path: str):
"""
example: s3://abc/xxxx.json?bytes=0,81350 ==> [0, 81350]
"""
arr = s3path.split("?bytes=")
if len(arr) == 1:
return None
return arr[1].split(",")
# Copyright (c) Opendatalab. All rights reserved.
from pydantic import BaseModel, Field
class S3Config(BaseModel):
"""S3 config
"""
bucket_name: str = Field(description='s3 bucket name', min_length=1)
access_key: str = Field(description='s3 access key', min_length=1)
secret_key: str = Field(description='s3 secret key', min_length=1)
endpoint_url: str = Field(description='s3 endpoint url', min_length=1)
addressing_style: str = Field(description='s3 addressing style', default='auto', min_length=1)
class PageInfo(BaseModel):
"""The width and height of page
"""
w: float = Field(description='the width of page')
h: float = Field(description='the height of page')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment