# batch_llm_inference.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
"""
This example shows how to use Ray Data for data parallel batch inference.

Ray Data is a data processing framework that can handle large datasets
and integrates tightly with vLLM for data-parallel inference.

As of Ray 2.44, Ray Data has a native integration with
vLLM (under ray.data.llm).

Ray Data provides functionality for:
* Reading and writing to cloud storage (S3, GCS, etc.)
* Automatic sharding and load-balancing across a cluster
* Optimized configuration of vLLM using continuous batching
* Compatibility with tensor/pipeline parallel inference

Learn more about Ray Data's LLM integration:
https://docs.ray.io/en/latest/data/working-with-llms.html
"""
21

22
23
24
25
import ray
from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

26
27
28
# Fail fast when the installed Ray is older than the minimum this example
# expects. An explicit raise is used rather than `assert` so the check is
# still enforced under `python -O`, which strips assert statements.
if Version(ray.__version__) < Version("2.44.1"):
    raise RuntimeError("Ray version must be at least 2.44.1")
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

# Uncomment to reduce clutter in stdout
# ray.init(log_to_driver=False)
# ray.data.DataContext.get_current().enable_progress_bars = False

# Load the prompts from a single text file on S3. Ray Data can also read
# multiple files from cloud storage (such as JSONL, Parquet, CSV, binary
# format).
prompts_uri = "s3://anonymous@air-example-data/prompts.txt"
ds = ray.data.read_text(prompts_uri)

# Show the dataset's schema before running inference.
print(ds.schema())

# Report how many prompts were loaded.
size = ds.count()
print(f"Size of dataset: {size} prompts")

# Keyword arguments passed through to the vLLM engine.
engine_kwargs = dict(
    enable_chunked_prefill=True,
    max_num_batched_tokens=4096,
    max_model_len=16384,
)

# Configure the vLLM engine processor: which model to serve, the engine
# settings above, and how the work is parallelized and batched.
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs=engine_kwargs,
    # Number of parallel vLLM replicas.
    concurrency=1,
    batch_size=64,
)

# Create a Processor object, which will be used to
# do batch inference on the dataset.
vllm_processor = build_llm_processor(
    config,
    # preprocess: turn an input row into a chat request. The row's "text"
    # column becomes the user message, preceded by a fixed system prompt,
    # and the sampling parameters are attached to every request.
    preprocess=lambda row: dict(
        messages=[
            {"role": "system", "content": "You are a bot that responds with haikus."},
            {"role": "user", "content": row["text"]},
        ],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=250,
        ),
    ),
    # postprocess: expose the generated text under "answer" while keeping
    # every column already present on the row.
    # NOTE(review): dict(answer=..., **row) raises TypeError if the row
    # already carries an "answer" column -- confirm inputs never do.
    postprocess=lambda row: dict(
        answer=row["generated_text"],
        **row,  # This will return all the original columns in the dataset.
    ),
)

# Apply the processor to the dataset of prompts.
ds = vllm_processor(ds)

# Look at the first 10 results only.
# NOTE: This is meant for local testing and debugging. In production,
# write the full result out instead (see the Parquet example below).
outputs = ds.take(limit=10)

for record in outputs:
    print(f"Prompt: {record['prompt']!r}")
    print(f"Generated text: {record['generated_text']!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")