# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use Ray Data to run offline batch inference
in a distributed fashion on a multi-node cluster.

Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
"""

9
from typing import Any, Dict, List
10

11
12
import numpy as np
import ray
13
14
from packaging.version import Version
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
15

16
17
from vllm import LLM, SamplingParams

18
19
20
# Fail fast when the installed Ray release is older than 2.22.0.
# An explicit raise is used instead of a bare `assert` so the check still
# runs under `python -O`, which strips assert statements. AssertionError and
# the message are kept so the failure looks the same as before.
if Version(ray.__version__) < Version("2.22.0"):
    raise AssertionError("Ray version must be at least 2.22.0")

21
22
23
# Sampling settings passed to every generate() call in this script.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
)

# Tensor parallelism per LLM instance.
tensor_parallel_size = 1

# Number of LLM instances to run. Each instance uses tensor_parallel_size
# GPUs.
num_instances = 1

30
31
32
33
34
35

# Create a class to do batch inference.
class LLMPredictor:
    """Callable class that runs vLLM generation on batches of prompts.

    The class itself (not an instance) is handed to ``ds.map_batches``; the
    model is loaded once in ``__init__`` and reused for every batch passed
    to ``__call__``.
    """

    def __init__(self):
        # Create an LLM, sharded according to the module-level
        # tensor_parallel_size setting.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
                       tensor_parallel_size=tensor_parallel_size)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        """Generate text for every prompt in ``batch``.

        Args:
            batch: Mapping whose "text" entry holds the prompts.

        Returns:
            A dict with two index-aligned lists: "prompt" (the prompt of
            each request) and "generated_text" (its generated text).
        """
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the prompt,
        # generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)

        prompt: List[str] = []
        generated_text: List[str] = []
        for output in outputs:
            prompt.append(output.prompt)
            # A request can carry several completions; they are merged into
            # a single space-separated string.
            generated_text.append(' '.join(o.text for o in output.outputs))
        return {
            "prompt": prompt,
            "generated_text": generated_text,
        }


# Read one text file from S3 into a Ray Dataset. Ray Data also supports
# reading multiple files from cloud storage in other formats (such as JSONL,
# Parquet, CSV, or binary).
# LLMPredictor later reads the prompts from each batch's "text" entry.
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

# For tensor_parallel_size > 1, we need to create placement groups for vLLM
# to use. Every actor has to have its own placement group.
def scheduling_strategy_fn():
    """Create a new placement group and return remote args that use it.

    Returns:
        A dict with a single "scheduling_strategy" entry holding a
        PlacementGroupSchedulingStrategy bound to the new placement group.
    """
    # One GPU + CPU bundle per tensor parallel worker.
    bundles = [{"GPU": 1, "CPU": 1} for _ in range(tensor_parallel_size)]
    group = ray.util.placement_group(bundles, strategy="STRICT_PACK")
    strategy = PlacementGroupSchedulingStrategy(
        group,
        placement_group_capture_child_tasks=True,
    )
    return {"scheduling_strategy": strategy}


75
# Resource arguments forwarded to map_batches.
# - tensor_parallel_size == 1: simply request one GPU per instance.
# - otherwise: request num_gpus=0 and supply a function that creates a
#   placement group for each instance.
resources_kwarg: Dict[str, Any] = (
    {"num_gpus": 1}
    if tensor_parallel_size == 1
    else {
        "num_gpus": 0,
        "ray_remote_args_fn": scheduling_strategy_fn,
    }
)

86
87
88
89
# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=num_instances,
    # Specify the batch size for inference.
    batch_size=32,
    # GPU request (and, for tensor_parallel_size > 1, the placement-group
    # factory) prepared in resources_kwarg.
    **resources_kwarg,
)

# Peek first 10 results.
# NOTE: This is for local testing and debugging. For production use case,
# one should write full result out as shown below.
outputs = ds.take(limit=10)
for row in outputs:
    # Each row carries the two fields produced by LLMPredictor.__call__.
    prompt, generated_text = row["prompt"], row["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")