Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
4abf6336
Unverified
Commit
4abf6336
authored
Feb 02, 2024
by
Cheng Su
Committed by
GitHub
Feb 02, 2024
Browse files
Add one example to run batch inference distributed on Ray (#2696)
parent
0e163fce
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
70 additions
and
0 deletions
+70
-0
examples/offline_inference_distributed.py
examples/offline_inference_distributed.py
+70
-0
No files found.
examples/offline_inference_distributed.py
0 → 100644
View file @
4abf6336
"""
This example shows how to use Ray Data for running offline batch inference
distributively on a multi-nodes cluster.
Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
"""
from
vllm
import
LLM
,
SamplingParams
from
typing
import
Dict
import
numpy
as
np
import
ray
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
# Create a class to do batch inference.
class
LLMPredictor
:
def
__init__
(
self
):
# Create an LLM.
self
.
llm
=
LLM
(
model
=
"meta-llama/Llama-2-7b-chat-hf"
)
def
__call__
(
self
,
batch
:
Dict
[
str
,
np
.
ndarray
])
->
Dict
[
str
,
list
]:
# Generate texts from the prompts.
# The output is a list of RequestOutput objects that contain the prompt,
# generated text, and other information.
outputs
=
self
.
llm
.
generate
(
batch
[
"text"
],
sampling_params
)
prompt
=
[]
generated_text
=
[]
for
output
in
outputs
:
prompt
.
append
(
output
.
prompt
)
generated_text
.
append
(
' '
.
join
([
o
.
text
for
o
in
output
.
outputs
]))
return
{
"prompt"
:
prompt
,
"generated_text"
:
generated_text
,
}
# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
ds
=
ray
.
data
.
read_text
(
"s3://anonymous@air-example-data/prompts.txt"
)
# Apply batch inference for all input data.
ds
=
ds
.
map_batches
(
LLMPredictor
,
# Set the concurrency to the number of LLM instances.
concurrency
=
10
,
# Specify the number of GPUs required per LLM instance.
# NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism
# (i.e., `tensor_parallel_size`).
num_gpus
=
1
,
# Specify the batch size for inference.
batch_size
=
32
,
)
# Peek first 10 results.
# NOTE: This is for local testing and debugging. For production use case,
# one should write full result out as shown below.
outputs
=
ds
.
take
(
limit
=
10
)
for
output
in
outputs
:
prompt
=
output
[
"prompt"
]
generated_text
=
output
[
"generated_text"
]
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment