s3_utils.py 2.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import fnmatch
5
from typing import TYPE_CHECKING, Optional
6

7
8
from vllm.utils import PlaceholderModule

9
10
11
if TYPE_CHECKING:
    from botocore.client import BaseClient

12
13
14
15
try:
    import boto3
except ImportError:
    boto3 = PlaceholderModule("boto3")  # type: ignore[assignment]
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
    return [
        path for path in paths if any(
            fnmatch.fnmatch(path, pattern) for pattern in patterns)
    ]


def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
    return [
        path for path in paths
        if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
    ]


32
def glob(s3: Optional["BaseClient"] = None,
         path: str = "",
         allow_pattern: Optional[list[str]] = None) -> list[str]:
    """
    Return the full ``s3://`` names of the files under an S3 path.

    Args:
        s3: S3 client to use. When omitted, one is created with
            ``boto3.client("s3")``.
        path: The S3 path to list from. A trailing ``/`` is appended
            when missing.
        allow_pattern: A list of patterns of which files to pull.

    Returns:
        list[str]: Full S3 paths (``s3://<bucket>/<key>``) allowed by
        the pattern.
    """
    client = boto3.client("s3") if s3 is None else s3
    listing_path = path if path.endswith("/") else path + "/"

    bucket_name, _, keys = list_files(client,
                                      path=listing_path,
                                      allow_pattern=allow_pattern)
    return [f"s3://{bucket_name}/{key}" for key in keys]


def list_files(
        s3: "BaseClient",
        path: str,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None
) -> tuple[str, str, list[str]]:
    """
    List files from S3 path and filter by pattern.

    All result pages are fetched: ``list_objects_v2`` is called again
    with the continuation token for as long as the response is marked
    as truncated, so listings are not cut off at a single page.

    Args:
        s3: S3 client to use.
        path: The S3 path to list from, with or without the ``s3://``
            scheme.
        allow_pattern: A list of patterns of which files to pull.
        ignore_pattern: A list of patterns of which files not to pull.

    Returns:
        tuple[str, str, list[str]]: A tuple where:
            - The first element is the bucket name
            - The second element is the key prefix, i.e. the path with
              the scheme and the bucket name removed
            - The third element is the list of object keys left after
              applying the allow and ignore patterns
    """
    # Everything before the first "/" is the bucket; the rest is the prefix.
    bucket_name, _, prefix = path.removeprefix('s3://').partition('/')

    request: dict[str, str] = {'Bucket': bucket_name, 'Prefix': prefix}
    paths: list[str] = []
    while True:
        objects = s3.list_objects_v2(**request)
        paths.extend(obj['Key'] for obj in objects.get('Contents', []))

        # A single response is capped (1000 keys by default); keep going
        # only while S3 reports more data and hands back a usable token.
        token = objects.get('NextContinuationToken')
        if not objects.get('IsTruncated') or not isinstance(token, str):
            break
        request['ContinuationToken'] = token

    # Keys ending in "/" are directory placeholders, not files.
    paths = _filter_ignore(paths, ["*/"])
    if allow_pattern is not None:
        paths = _filter_allow(paths, allow_pattern)

    if ignore_pattern is not None:
        paths = _filter_ignore(paths, ignore_pattern)

    return bucket_name, prefix, paths