Unverified Commit 27bebcd8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Convert `examples` to `ruff-format` (#18400)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent e7523c2e
...@@ -17,7 +17,7 @@ repos: ...@@ -17,7 +17,7 @@ repos:
- id: ruff - id: ruff
args: [--output-format, github, --fix] args: [--output-format, github, --fix]
- id: ruff-format - id: ruff-format
files: ^(.buildkite|benchmarks)/.* files: ^(.buildkite|benchmarks|examples)/.*
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
rev: v2.4.1 rev: v2.4.1
hooks: hooks:
......
...@@ -6,6 +6,7 @@ with the correct prompt format on audio language models. ...@@ -6,6 +6,7 @@ with the correct prompt format on audio language models.
For most models, the prompt format should follow corresponding examples For most models, the prompt format should follow corresponding examples
on HuggingFace model repository. on HuggingFace model repository.
""" """
import os import os
from dataclasses import asdict from dataclasses import asdict
from typing import NamedTuple, Optional from typing import NamedTuple, Optional
...@@ -22,7 +23,7 @@ audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] ...@@ -22,7 +23,7 @@ audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
question_per_audio_count = { question_per_audio_count = {
0: "What is 1+1?", 0: "What is 1+1?",
1: "What is recited in the audio?", 1: "What is recited in the audio?",
2: "What sport and what nursery rhyme are referenced?" 2: "What sport and what nursery rhyme are referenced?",
} }
...@@ -72,8 +73,7 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: ...@@ -72,8 +73,7 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# MiniCPM-O # MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model_name = "openbmb/MiniCPM-o-2_6" model_name = "openbmb/MiniCPM-o-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True)
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
trust_remote_code=True, trust_remote_code=True,
...@@ -82,19 +82,18 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: ...@@ -82,19 +82,18 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
limit_mm_per_prompt={"audio": audio_count}, limit_mm_per_prompt={"audio": audio_count},
) )
stop_tokens = ['<|im_end|>', '<|endoftext|>'] stop_tokens = ["<|im_end|>", "<|endoftext|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
audio_placeholder = "(<audio>./</audio>)" * audio_count audio_placeholder = "(<audio>./</audio>)" * audio_count
audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501 audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501
messages = [{ messages = [{"role": "user", "content": f"{audio_placeholder}\n{question}"}]
'role': 'user', prompt = tokenizer.apply_chat_template(
'content': f'{audio_placeholder}\n{question}' messages,
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False, tokenize=False,
add_generation_prompt=True, add_generation_prompt=True,
chat_template=audio_chat_template) chat_template=audio_chat_template,
)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -113,7 +112,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData: ...@@ -113,7 +112,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
# Since the vision-lora and speech-lora co-exist with the base model, # Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights. # we have to manually specify the path of the lora weights.
speech_lora_path = os.path.join(model_path, "speech-lora") speech_lora_path = os.path.join(model_path, "speech-lora")
placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)]) placeholders = "".join([f"<|audio_{i + 1}|>" for i in range(audio_count)])
prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>" prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
...@@ -145,15 +144,19 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData: ...@@ -145,15 +144,19 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
limit_mm_per_prompt={"audio": audio_count}, limit_mm_per_prompt={"audio": audio_count},
) )
audio_in_prompt = "".join([ audio_in_prompt = "".join(
f"Audio {idx+1}: " [
f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count) f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
]) for idx in range(audio_count)
]
)
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" prompt = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n" "<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n" f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n") "<|im_start|>assistant\n"
)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -172,19 +175,22 @@ def run_qwen2_5_omni(question: str, audio_count: int): ...@@ -172,19 +175,22 @@ def run_qwen2_5_omni(question: str, audio_count: int):
limit_mm_per_prompt={"audio": audio_count}, limit_mm_per_prompt={"audio": audio_count},
) )
audio_in_prompt = "".join([ audio_in_prompt = "".join(
"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count) ["<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)]
]) )
default_system = ( default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba " "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as " "Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.") "generating text and speech."
)
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n" "<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n" f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n") "<|im_start|>assistant\n"
)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
prompt=prompt, prompt=prompt,
...@@ -196,13 +202,10 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData: ...@@ -196,13 +202,10 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b" model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{ messages = [{"role": "user", "content": "<|audio|>\n" * audio_count + question}]
'role': 'user', prompt = tokenizer.apply_chat_template(
'content': "<|audio|>\n" * audio_count + question messages, tokenize=False, add_generation_prompt=True
}] )
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
...@@ -220,8 +223,7 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData: ...@@ -220,8 +223,7 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
# Whisper # Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData: def run_whisper(question: str, audio_count: int) -> ModelRequestData:
assert audio_count == 1, ( assert audio_count == 1, "Whisper only support single audio input per prompt"
"Whisper only support single audio input per prompt")
model_name = "openai/whisper-large-v3-turbo" model_name = "openai/whisper-large-v3-turbo"
prompt = "<|startoftranscript|>" prompt = "<|startoftranscript|>"
...@@ -252,27 +254,33 @@ model_example_map = { ...@@ -252,27 +254,33 @@ model_example_map = {
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description="Demo on using vLLM for offline inference with "
'audio language models') "audio language models"
parser.add_argument('--model-type', )
'-m', parser.add_argument(
"--model-type",
"-m",
type=str, type=str,
default="ultravox", default="ultravox",
choices=model_example_map.keys(), choices=model_example_map.keys(),
help='Huggingface "model_type".') help='Huggingface "model_type".',
parser.add_argument('--num-prompts', )
type=int, parser.add_argument(
default=1, "--num-prompts", type=int, default=1, help="Number of prompts to run."
help='Number of prompts to run.') )
parser.add_argument("--num-audios", parser.add_argument(
"--num-audios",
type=int, type=int,
default=1, default=1,
choices=[0, 1, 2], choices=[0, 1, 2],
help="Number of audio items per prompt.") help="Number of audio items per prompt.",
parser.add_argument("--seed", )
parser.add_argument(
"--seed",
type=int, type=int,
default=None, default=None,
help="Set the seed when initializing `vllm.LLM`.") help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args() return parser.parse_args()
...@@ -283,29 +291,30 @@ def main(args): ...@@ -283,29 +291,30 @@ def main(args):
raise ValueError(f"Model type {model} is not supported.") raise ValueError(f"Model type {model} is not supported.")
audio_count = args.num_audios audio_count = args.num_audios
req_data = model_example_map[model](question_per_audio_count[audio_count], req_data = model_example_map[model](
audio_count) question_per_audio_count[audio_count], audio_count
)
# Disable other modalities to save memory # Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0} default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict( req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {}) req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed} engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args) llm = LLM(**engine_args)
# We set temperature to 0.2 so that outputs can be different # We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference. # even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2, sampling_params = SamplingParams(
max_tokens=64, temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
stop_token_ids=req_data.stop_token_ids) )
mm_data = {} mm_data = {}
if audio_count > 0: if audio_count > 0:
mm_data = { mm_data = {
"audio": [ "audio": [
asset.audio_and_sample_rate asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
for asset in audio_assets[:audio_count]
] ]
} }
...@@ -315,8 +324,9 @@ def main(args): ...@@ -315,8 +324,9 @@ def main(args):
# Batch inference # Batch inference
inputs = [inputs] * args.num_prompts inputs = [inputs] * args.num_prompts
# Add LoRA request if applicable # Add LoRA request if applicable
lora_request = (req_data.lora_requests * lora_request = (
args.num_prompts if req_data.lora_requests else None) req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
)
outputs = llm.generate( outputs = llm.generate(
inputs, inputs,
......
...@@ -16,13 +16,16 @@ but ask different questions. ...@@ -16,13 +16,16 @@ but ask different questions.
Run: Run:
python examples/offline_inference/automatic_prefix_caching.py python examples/offline_inference/automatic_prefix_caching.py
""" """
import time import time
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
# ruff: noqa: E501 # ruff: noqa: E501
# A prompt containing a large markdown table. The table is randomly generated by GPT-4. # A prompt containing a large markdown table. The table is randomly generated by GPT-4.
LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n" + """ LONG_PROMPT = (
"You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n"
+ """
| ID | Name | Age | Occupation | Country | Email | Phone Number | Address | | ID | Name | Age | Occupation | Country | Email | Phone Number | Address |
|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------| |-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------|
| 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL | | 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL |
...@@ -56,6 +59,7 @@ LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables i ...@@ -56,6 +59,7 @@ LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables i
| 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ | | 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ |
| 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE | | 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE |
""" """
)
def get_generation_time(llm, sampling_params, prompts): def get_generation_time(llm, sampling_params, prompts):
...@@ -72,7 +76,7 @@ def get_generation_time(llm, sampling_params, prompts): ...@@ -72,7 +76,7 @@ def get_generation_time(llm, sampling_params, prompts):
def main(): def main():
# set enable_prefix_caching=True to enable APC # set enable_prefix_caching=True to enable APC
llm = LLM(model='lmsys/longchat-13b-16k', enable_prefix_caching=True) llm = LLM(model="lmsys/longchat-13b-16k", enable_prefix_caching=True)
sampling_params = SamplingParams(temperature=0, max_tokens=100) sampling_params = SamplingParams(temperature=0, max_tokens=100)
...@@ -80,8 +84,8 @@ def main(): ...@@ -80,8 +84,8 @@ def main():
get_generation_time( get_generation_time(
llm, llm,
sampling_params, sampling_params,
LONG_PROMPT + LONG_PROMPT
"Question: what is the age of John Doe? Your answer: The age of John Doe is ", + "Question: what is the age of John Doe? Your answer: The age of John Doe is ",
) )
# Querying the age of Zack Blue # Querying the age of Zack Blue
...@@ -89,8 +93,8 @@ def main(): ...@@ -89,8 +93,8 @@ def main():
get_generation_time( get_generation_time(
llm, llm,
sampling_params, sampling_params,
LONG_PROMPT + LONG_PROMPT
"Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ", + "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ",
) )
......
...@@ -56,22 +56,12 @@ def main(args: dict): ...@@ -56,22 +56,12 @@ def main(args: dict):
# In this script, we demonstrate how to pass input to the chat method: # In this script, we demonstrate how to pass input to the chat method:
conversation = [ conversation = [
{ {"role": "system", "content": "You are a helpful assistant"},
"role": "system", {"role": "user", "content": "Hello"},
"content": "You are a helpful assistant" {"role": "assistant", "content": "Hello! How can I assist you today?"},
},
{
"role": "user",
"content": "Hello"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{ {
"role": "user", "role": "user",
"content": "content": "Write an essay about the importance of higher education.",
"Write an essay about the importance of higher education.",
}, },
] ]
outputs = llm.chat(conversation, sampling_params, use_tqdm=False) outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
......
...@@ -10,9 +10,9 @@ def parse_args(): ...@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments # Set example specific arguments
parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach", parser.set_defaults(
task="classify", model="jason9693/Qwen2.5-1.5B-apeach", task="classify", enforce_eager=True
enforce_eager=True) )
return parser.parse_args() return parser.parse_args()
...@@ -36,10 +36,11 @@ def main(args: Namespace): ...@@ -36,10 +36,11 @@ def main(args: Namespace):
print("\nGenerated Outputs:\n" + "-" * 60) print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs): for prompt, output in zip(prompts, outputs):
probs = output.outputs.probs probs = output.outputs.probs
probs_trimmed = ((str(probs[:16])[:-1] + probs_trimmed = (str(probs[:16])[:-1] + ", ...]") if len(probs) > 16 else probs
", ...]") if len(probs) > 16 else probs) print(
print(f"Prompt: {prompt!r} \n" f"Prompt: {prompt!r} \n"
f"Class Probabilities: {probs_trimmed} (size={len(probs)})") f"Class Probabilities: {probs_trimmed} (size={len(probs)})"
)
print("-" * 60) print("-" * 60)
......
...@@ -10,9 +10,9 @@ def parse_args(): ...@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments # Set example specific arguments
parser.set_defaults(model="intfloat/e5-mistral-7b-instruct", parser.set_defaults(
task="embed", model="intfloat/e5-mistral-7b-instruct", task="embed", enforce_eager=True
enforce_eager=True) )
return parser.parse_args() return parser.parse_args()
...@@ -36,10 +36,10 @@ def main(args: Namespace): ...@@ -36,10 +36,10 @@ def main(args: Namespace):
print("\nGenerated Outputs:\n" + "-" * 60) print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs): for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding embeds = output.outputs.embedding
embeds_trimmed = ((str(embeds[:16])[:-1] + embeds_trimmed = (
", ...]") if len(embeds) > 16 else embeds) (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
print(f"Prompt: {prompt!r} \n" )
f"Embeddings: {embeds_trimmed} (size={len(embeds)})") print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})")
print("-" * 60) print("-" * 60)
......
...@@ -10,9 +10,9 @@ def parse_args(): ...@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments # Set example specific arguments
parser.set_defaults(model="BAAI/bge-reranker-v2-m3", parser.set_defaults(
task="score", model="BAAI/bge-reranker-v2-m3", task="score", enforce_eager=True
enforce_eager=True) )
return parser.parse_args() return parser.parse_args()
......
...@@ -17,12 +17,14 @@ Ray Data provides functionality for: ...@@ -17,12 +17,14 @@ Ray Data provides functionality for:
Learn more about Ray Data's LLM integration: Learn more about Ray Data's LLM integration:
https://docs.ray.io/en/latest/data/working-with-llms.html https://docs.ray.io/en/latest/data/working-with-llms.html
""" """
import ray import ray
from packaging.version import Version from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
assert Version(ray.__version__) >= Version( assert Version(ray.__version__) >= Version("2.44.1"), (
"2.44.1"), "Ray version must be at least 2.44.1" "Ray version must be at least 2.44.1"
)
# Uncomment to reduce clutter in stdout # Uncomment to reduce clutter in stdout
# ray.init(log_to_driver=False) # ray.init(log_to_driver=False)
...@@ -53,20 +55,18 @@ config = vLLMEngineProcessorConfig( ...@@ -53,20 +55,18 @@ config = vLLMEngineProcessorConfig(
vllm_processor = build_llm_processor( vllm_processor = build_llm_processor(
config, config,
preprocess=lambda row: dict( preprocess=lambda row: dict(
messages=[{ messages=[
"role": "system", {"role": "system", "content": "You are a bot that responds with haikus."},
"content": "You are a bot that responds with haikus." {"role": "user", "content": row["text"]},
}, { ],
"role": "user",
"content": row["text"]
}],
sampling_params=dict( sampling_params=dict(
temperature=0.3, temperature=0.3,
max_tokens=250, max_tokens=250,
)), ),
),
postprocess=lambda row: dict( postprocess=lambda row: dict(
answer=row["generated_text"], answer=row["generated_text"],
**row # This will return all the original columns in the dataset. **row, # This will return all the original columns in the dataset.
), ),
) )
......
...@@ -50,27 +50,32 @@ model_name = "mistralai/Mistral-7B-Instruct-v0.3" ...@@ -50,27 +50,32 @@ model_name = "mistralai/Mistral-7B-Instruct-v0.3"
# or any other mistral model with function calling ability # or any other mistral model with function calling ability
sampling_params = SamplingParams(max_tokens=8192, temperature=0.0) sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)
llm = LLM(model=model_name, llm = LLM(
model=model_name,
tokenizer_mode="mistral", tokenizer_mode="mistral",
config_format="mistral", config_format="mistral",
load_format="mistral") load_format="mistral",
)
def generate_random_id(length=9): def generate_random_id(length=9):
characters = string.ascii_letters + string.digits characters = string.ascii_letters + string.digits
random_id = ''.join(random.choice(characters) for _ in range(length)) random_id = "".join(random.choice(characters) for _ in range(length))
return random_id return random_id
# simulate an API that can be called # simulate an API that can be called
def get_current_weather(city: str, state: str, unit: 'str'): def get_current_weather(city: str, state: str, unit: "str"):
return (f"The weather in {city}, {state} is 85 degrees {unit}. It is " return (
"partly cloudly, with highs in the 90's.") f"The weather in {city}, {state} is 85 degrees {unit}. It is "
"partly cloudly, with highs in the 90's."
)
tool_functions = {"get_current_weather": get_current_weather} tool_functions = {"get_current_weather": get_current_weather}
tools = [{ tools = [
{
"type": "function", "type": "function",
"function": { "function": {
"name": "get_current_weather", "name": "get_current_weather",
...@@ -79,58 +84,59 @@ tools = [{ ...@@ -79,58 +84,59 @@ tools = [{
"type": "object", "type": "object",
"properties": { "properties": {
"city": { "city": {
"type": "type": "string",
"string", "description": "The city to find the weather for, e.g. 'San Francisco'",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
}, },
"state": { "state": {
"type": "type": "string",
"string", "description": "the two-letter abbreviation for the state that the city is"
"description": " in, e.g. 'CA' which would mean 'California'",
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
}, },
"unit": { "unit": {
"type": "string", "type": "string",
"description": "The unit to fetch the temperature in", "description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"] "enum": ["celsius", "fahrenheit"],
} },
},
"required": ["city", "state", "unit"],
},
}, },
"required": ["city", "state", "unit"]
}
} }
}] ]
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": "content": "Can you tell me what the temperate will be in Dallas, in fahrenheit?",
"Can you tell me what the temperate will be in Dallas, in fahrenheit?" }
}] ]
outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools) outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
output = outputs[0].outputs[0].text.strip() output = outputs[0].outputs[0].text.strip()
# append the assistant message # append the assistant message
messages.append({ messages.append(
{
"role": "assistant", "role": "assistant",
"content": output, "content": output,
}) }
)
# let's now actually parse and execute the model's output simulating an API call by using the # let's now actually parse and execute the model's output simulating an API call by using the
# above defined function # above defined function
tool_calls = json.loads(output) tool_calls = json.loads(output)
tool_answers = [ tool_answers = [
tool_functions[call['name']](**call['arguments']) for call in tool_calls tool_functions[call["name"]](**call["arguments"]) for call in tool_calls
] ]
# append the answer as a tool message and let the LLM give you an answer # append the answer as a tool message and let the LLM give you an answer
messages.append({ messages.append(
{
"role": "tool", "role": "tool",
"content": "\n\n".join(tool_answers), "content": "\n\n".join(tool_answers),
"tool_call_id": generate_random_id(), "tool_call_id": generate_random_id(),
}) }
)
outputs = llm.chat(messages, sampling_params, tools=tools) outputs = llm.chat(messages, sampling_params, tools=tools)
......
...@@ -27,6 +27,7 @@ Multi-node: ...@@ -27,6 +27,7 @@ Multi-node:
--master-addr=10.99.48.128 \ --master-addr=10.99.48.128 \
--master-port=13345 --master-port=13345
""" """
import os import os
from time import sleep from time import sleep
...@@ -36,46 +37,46 @@ from vllm.utils import get_open_port ...@@ -36,46 +37,46 @@ from vllm.utils import get_open_port
def parse_args(): def parse_args():
import argparse import argparse
parser = argparse.ArgumentParser(description="Data Parallel Inference") parser = argparse.ArgumentParser(description="Data Parallel Inference")
parser.add_argument("--model", parser.add_argument(
"--model",
type=str, type=str,
default="ibm-research/PowerMoE-3b", default="ibm-research/PowerMoE-3b",
help="Model name or path") help="Model name or path",
parser.add_argument("--dp-size", )
type=int, parser.add_argument("--dp-size", type=int, default=2, help="Data parallel size")
default=2, parser.add_argument("--tp-size", type=int, default=2, help="Tensor parallel size")
help="Data parallel size") parser.add_argument(
parser.add_argument("--tp-size", "--node-size", type=int, default=1, help="Total number of nodes"
type=int, )
default=2, parser.add_argument(
help="Tensor parallel size") "--node-rank", type=int, default=0, help="Rank of the current node"
parser.add_argument("--node-size", )
type=int, parser.add_argument(
default=1, "--master-addr", type=str, default="", help="Master node IP address"
help="Total number of nodes") )
parser.add_argument("--node-rank", parser.add_argument("--master-port", type=int, default=0, help="Master node port")
type=int, parser.add_argument(
default=0, "--enforce-eager", action="store_true", help="Enforce eager mode execution."
help="Rank of the current node") )
parser.add_argument("--master-addr", parser.add_argument(
type=str, "--trust-remote-code", action="store_true", help="Trust remote code."
default="", )
help="Master node IP address")
parser.add_argument("--master-port",
type=int,
default=0,
help="Master node port")
parser.add_argument("--enforce-eager",
action='store_true',
help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code",
action='store_true',
help="Trust remote code.")
return parser.parse_args() return parser.parse_args()
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, def main(
dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code): model,
dp_size,
local_dp_rank,
global_dp_rank,
dp_master_ip,
dp_master_port,
GPUs_per_dp_rank,
enforce_eager,
trust_remote_code,
):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size) os.environ["VLLM_DP_SIZE"] = str(dp_size)
...@@ -110,9 +111,9 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, ...@@ -110,9 +111,9 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
# since we are doing data parallel, every rank can have different # since we are doing data parallel, every rank can have different
# sampling params. here we set different max_tokens for different # sampling params. here we set different max_tokens for different
# ranks for demonstration. # ranks for demonstration.
sampling_params = SamplingParams(temperature=0.8, sampling_params = SamplingParams(
top_p=0.95, temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
max_tokens=[16, 20][global_dp_rank % 2]) )
# Create an LLM. # Create an LLM.
llm = LLM( llm = LLM(
...@@ -130,15 +131,16 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, ...@@ -130,15 +131,16 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
break break
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, " print(
f"Generated text: {generated_text!r}") f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}"
)
# Give engines time to pause their processing loops before exiting. # Give engines time to pause their processing loops before exiting.
sleep(1) sleep(1)
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
dp_size = args.dp_size dp_size = args.dp_size
...@@ -160,20 +162,29 @@ if __name__ == "__main__": ...@@ -160,20 +162,29 @@ if __name__ == "__main__":
procs = [] procs = []
for local_dp_rank, global_dp_rank in enumerate( for local_dp_rank, global_dp_rank in enumerate(
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)): range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
proc = Process(target=main, ):
args=(args.model, dp_size, local_dp_rank, proc = Process(
global_dp_rank, dp_master_ip, dp_master_port, target=main,
tp_size, args.enforce_eager, args=(
args.trust_remote_code)) args.model,
dp_size,
local_dp_rank,
global_dp_rank,
dp_master_ip,
dp_master_port,
tp_size,
args.enforce_eager,
args.trust_remote_code,
),
)
proc.start() proc.start()
procs.append(proc) procs.append(proc)
exit_code = 0 exit_code = 0
for proc in procs: for proc in procs:
proc.join(timeout=300) proc.join(timeout=300)
if proc.exitcode is None: if proc.exitcode is None:
print(f"Killing process {proc.pid} that " print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
f"didn't stop within 5 minutes.")
proc.kill() proc.kill()
exit_code = 1 exit_code = 1
elif proc.exitcode: elif proc.exitcode:
......
...@@ -22,7 +22,8 @@ def main(): ...@@ -22,7 +22,8 @@ def main():
prompts = read_prompts() prompts = read_prompts()
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True, enforce_eager=True,
gpu_memory_utilization=0.8, gpu_memory_utilization=0.8,
max_num_batched_tokens=64, max_num_batched_tokens=64,
...@@ -30,9 +31,9 @@ def main(): ...@@ -30,9 +31,9 @@ def main():
kv_transfer_config=KVTransferConfig( kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector", kv_connector="SharedStorageConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={ kv_connector_extra_config={"shared_storage_path": "local_storage"},
"shared_storage_path": "local_storage" ),
})) #, max_model_len=2048, max_num_batched_tokens=2048) ) # , max_model_len=2048, max_num_batched_tokens=2048)
# 1ST generation (prefill instance) # 1ST generation (prefill instance)
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
......
...@@ -20,15 +20,16 @@ def main(): ...@@ -20,15 +20,16 @@ def main():
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True, enforce_eager=True,
gpu_memory_utilization=0.8, gpu_memory_utilization=0.8,
kv_transfer_config=KVTransferConfig( kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector", kv_connector="SharedStorageConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={ kv_connector_extra_config={"shared_storage_path": "local_storage"},
"shared_storage_path": "local_storage" ),
})) #, max_model_len=2048, max_num_batched_tokens=2048) ) # , max_model_len=2048, max_num_batched_tokens=2048)
# 1ST generation (prefill instance) # 1ST generation (prefill instance)
outputs = llm.generate( outputs = llm.generate(
......
...@@ -4,6 +4,7 @@ This file demonstrates the example usage of disaggregated prefilling ...@@ -4,6 +4,7 @@ This file demonstrates the example usage of disaggregated prefilling
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode), We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
and then transfer the KV cache between them. and then transfer the KV cache between them.
""" """
import os import os
import time import time
from multiprocessing import Event, Process from multiprocessing import Event, Process
...@@ -32,17 +33,21 @@ def run_prefill(prefill_done): ...@@ -32,17 +33,21 @@ def run_prefill(prefill_done):
# This instance is the prefill node (kv_producer, rank 0). # This instance is the prefill node (kv_producer, rank 0).
# The number of parallel instances for KV cache transfer is set to 2, # The number of parallel instances for KV cache transfer is set to 2,
# as required for PyNcclConnector. # as required for PyNcclConnector.
ktc = KVTransferConfig(kv_connector="PyNcclConnector", ktc = KVTransferConfig(
kv_connector="PyNcclConnector",
kv_role="kv_producer", kv_role="kv_producer",
kv_rank=0, kv_rank=0,
kv_parallel_size=2) kv_parallel_size=2,
)
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
# memory. You may need to adjust the value to fit your GPU. # memory. You may need to adjust the value to fit your GPU.
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", llm = LLM(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
kv_transfer_config=ktc, kv_transfer_config=ktc,
max_model_len=2000, max_model_len=2000,
gpu_memory_utilization=0.8) gpu_memory_utilization=0.8,
)
llm.generate(prompts, sampling_params) llm.generate(prompts, sampling_params)
print("Prefill node is finished.") print("Prefill node is finished.")
...@@ -72,17 +77,21 @@ def run_decode(prefill_done): ...@@ -72,17 +77,21 @@ def run_decode(prefill_done):
# This instance is the decode node (kv_consumer, rank 1). # This instance is the decode node (kv_consumer, rank 1).
# The number of parallel instances for KV cache transfer is set to 2, # The number of parallel instances for KV cache transfer is set to 2,
# as required for PyNcclConnector. # as required for PyNcclConnector.
ktc = KVTransferConfig(kv_connector="PyNcclConnector", ktc = KVTransferConfig(
kv_connector="PyNcclConnector",
kv_role="kv_consumer", kv_role="kv_consumer",
kv_rank=1, kv_rank=1,
kv_parallel_size=2) kv_parallel_size=2,
)
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
# memory. You may need to adjust the value to fit your GPU. # memory. You may need to adjust the value to fit your GPU.
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", llm = LLM(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
kv_transfer_config=ktc, kv_transfer_config=ktc,
max_model_len=2000, max_model_len=2000,
gpu_memory_utilization=0.8) gpu_memory_utilization=0.8,
)
# Wait for the producer to start the pipe # Wait for the producer to start the pipe
print("Waiting for prefill node to finish...") print("Waiting for prefill node to finish...")
...@@ -99,8 +108,8 @@ def run_decode(prefill_done): ...@@ -99,8 +108,8 @@ def run_decode(prefill_done):
def main(): def main():
prefill_done = Event() prefill_done = Event()
prefill_process = Process(target=run_prefill, args=(prefill_done, )) prefill_process = Process(target=run_prefill, args=(prefill_done,))
decode_process = Process(target=run_decode, args=(prefill_done, )) decode_process = Process(target=run_decode, args=(prefill_done,))
# Start prefill node # Start prefill node
prefill_process.start() prefill_process.start()
......
...@@ -20,9 +20,7 @@ def load_prompts(dataset_path, num_prompts): ...@@ -20,9 +20,7 @@ def load_prompts(dataset_path, num_prompts):
print(f"Error reading dataset: {e}") print(f"Error reading dataset: {e}")
return [] return []
else: else:
prompts = [ prompts = ["The future of AI is", "The president of the United States is"]
"The future of AI is", "The president of the United States is"
]
return prompts[:num_prompts] return prompts[:num_prompts]
...@@ -33,34 +31,32 @@ def parse_args(): ...@@ -33,34 +31,32 @@ def parse_args():
"--dataset", "--dataset",
type=str, type=str,
default="./examples/data/gsm8k.jsonl", default="./examples/data/gsm8k.jsonl",
help="downloaded from the eagle repo " \ help="downloaded from the eagle repo "
"https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/" "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/",
)
parser.add_argument(
"--method", type=str, default="eagle", choices=["eagle", "eagle3"]
) )
parser.add_argument("--method",
type=str,
default='eagle',
choices=['eagle', 'eagle3'])
parser.add_argument("--max_num_seqs", type=int, default=8) parser.add_argument("--max_num_seqs", type=int, default=8)
parser.add_argument("--num_prompts", type=int, default=80) parser.add_argument("--num_prompts", type=int, default=80)
parser.add_argument("--num_spec_tokens", type=int, default=2) parser.add_argument("--num_spec_tokens", type=int, default=2)
parser.add_argument("--tp", type=int, default=1) parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--draft_tp", type=int, default=1) parser.add_argument("--draft_tp", type=int, default=1)
parser.add_argument("--enforce_eager", action='store_true') parser.add_argument("--enforce_eager", action="store_true")
parser.add_argument("--enable_chunked_prefill", action='store_true') parser.add_argument("--enable_chunked_prefill", action="store_true")
parser.add_argument("--max_num_batched_tokens", type=int, default=2048) parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0) parser.add_argument("--temp", type=float, default=0)
return parser.parse_args() return parser.parse_args()
def main(): def main():
args = parse_args() args = parse_args()
model_dir = "meta-llama/Llama-3.1-8B-Instruct" model_dir = "meta-llama/Llama-3.1-8B-Instruct"
if args.method == 'eagle': if args.method == "eagle":
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
elif args.method == 'eagle3': elif args.method == "eagle3":
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
else: else:
raise ValueError(f"unknown method: {args.method}") raise ValueError(f"unknown method: {args.method}")
...@@ -72,11 +68,9 @@ def main(): ...@@ -72,11 +68,9 @@ def main():
prompts = load_prompts(args.dataset, args.num_prompts) prompts = load_prompts(args.dataset, args.num_prompts)
prompt_ids = [ prompt_ids = [
tokenizer.apply_chat_template([{ tokenizer.apply_chat_template(
"role": "user", [{"role": "user", "content": prompt}], add_generation_prompt=True
"content": prompt )
}],
add_generation_prompt=True)
for prompt in prompts for prompt in prompts
] ]
...@@ -102,8 +96,7 @@ def main(): ...@@ -102,8 +96,7 @@ def main():
sampling_params = SamplingParams(temperature=args.temp, max_tokens=256) sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
outputs = llm.generate(prompt_token_ids=prompt_ids, outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
sampling_params=sampling_params)
# print the generated text # print the generated text
for output in outputs: for output in outputs:
...@@ -120,19 +113,22 @@ def main(): ...@@ -120,19 +113,22 @@ def main():
# accepted # accepted
acceptance_counts = [0] * (args.num_spec_tokens + 1) acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs: for output in outputs:
for step, count in enumerate( for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
output.metrics.spec_token_acceptance_counts):
acceptance_counts[step] += count acceptance_counts[step] += count
print("-" * 50) print("-" * 50)
print(f"mean acceptance length (including bonus tokens): \ print(
{1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}") f"mean acceptance length (including bonus tokens): \
{1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}"
)
print("-" * 50) print("-" * 50)
# print acceptance at each token position # print acceptance at each token position
for i in range(len(acceptance_counts)): for i in range(len(acceptance_counts)):
print(f"acceptance at token {i}:" print(
f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}") f"acceptance at token {i}:"
f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}"
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -10,9 +10,9 @@ def parse_args(): ...@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments # Set example specific arguments
parser.set_defaults(model="jinaai/jina-embeddings-v3", parser.set_defaults(
task="embed", model="jinaai/jina-embeddings-v3", task="embed", trust_remote_code=True
trust_remote_code=True) )
return parser.parse_args() return parser.parse_args()
...@@ -41,11 +41,14 @@ def main(args: Namespace): ...@@ -41,11 +41,14 @@ def main(args: Namespace):
print("-" * 60) print("-" * 60)
for prompt, output in zip(prompts, outputs): for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding embeds = output.outputs.embedding
embeds_trimmed = ((str(embeds[:16])[:-1] + embeds_trimmed = (
", ...]") if len(embeds) > 16 else embeds) (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
print(f"Prompt: {prompt!r} \n" )
print(
f"Prompt: {prompt!r} \n"
f"Embeddings for text matching: {embeds_trimmed} " f"Embeddings for text matching: {embeds_trimmed} "
f"(size={len(embeds)})") f"(size={len(embeds)})"
)
print("-" * 60) print("-" * 60)
......
...@@ -10,9 +10,9 @@ def parse_args(): ...@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments # Set example specific arguments
parser.set_defaults(model="jinaai/jina-embeddings-v3", parser.set_defaults(
task="embed", model="jinaai/jina-embeddings-v3", task="embed", trust_remote_code=True
trust_remote_code=True) )
return parser.parse_args() return parser.parse_args()
...@@ -39,11 +39,10 @@ def main(args: Namespace): ...@@ -39,11 +39,10 @@ def main(args: Namespace):
print("-" * 60) print("-" * 60)
for prompt, output in zip(prompts, outputs): for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding embeds = output.outputs.embedding
embeds_trimmed = ((str(embeds[:16])[:-1] + embeds_trimmed = (
", ...]") if len(embeds) > 16 else embeds) (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
print(f"Prompt: {prompt!r} \n" )
f"Embeddings: {embeds_trimmed} " print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})")
f"(size={len(embeds)})")
print("-" * 60) print("-" * 60)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
''' """
Demonstrate prompting of text-to-text Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART encoder/decoder models, specifically BART
''' """
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, from vllm.inputs import (
TokensPrompt, zip_enc_dec_prompts) ExplicitEncoderDecoderPrompt,
TextPrompt,
TokensPrompt,
zip_enc_dec_prompts,
)
def create_prompts(tokenizer): def create_prompts(tokenizer):
...@@ -18,8 +22,9 @@ def create_prompts(tokenizer): ...@@ -18,8 +22,9 @@ def create_prompts(tokenizer):
# - Helpers for building prompts # - Helpers for building prompts
text_prompt_raw = "Hello, my name is" text_prompt_raw = "Hello, my name is"
text_prompt = TextPrompt(prompt="The president of the United States is") text_prompt = TextPrompt(prompt="The president of the United States is")
tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode( tokens_prompt = TokensPrompt(
prompt="The capital of France is")) prompt_token_ids=tokenizer.encode(prompt="The capital of France is")
)
# - Pass a single prompt to encoder/decoder model # - Pass a single prompt to encoder/decoder model
# (implicitly encoder input prompt); # (implicitly encoder input prompt);
# decoder input prompt is assumed to be None # decoder input prompt is assumed to be None
...@@ -57,14 +62,19 @@ def create_prompts(tokenizer): ...@@ -57,14 +62,19 @@ def create_prompts(tokenizer):
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt # decoder prompts together into a list of ExplicitEncoderDecoderPrompt
# instances # instances
zipped_prompt_list = zip_enc_dec_prompts( zipped_prompt_list = zip_enc_dec_prompts(
['An encoder prompt', 'Another encoder prompt'], ["An encoder prompt", "Another encoder prompt"],
['A decoder prompt', 'Another decoder prompt']) ["A decoder prompt", "Another decoder prompt"],
)
# - Let's put all of the above example prompts together into one list # - Let's put all of the above example prompts together into one list
# which we will pass to the encoder/decoder LLM. # which we will pass to the encoder/decoder LLM.
return [ return [
single_text_prompt_raw, single_text_prompt, single_tokens_prompt, single_text_prompt_raw,
enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 single_text_prompt,
single_tokens_prompt,
enc_dec_prompt1,
enc_dec_prompt2,
enc_dec_prompt3,
] + zipped_prompt_list ] + zipped_prompt_list
...@@ -85,10 +95,12 @@ def print_outputs(outputs): ...@@ -85,10 +95,12 @@ def print_outputs(outputs):
prompt = output.prompt prompt = output.prompt
encoder_prompt = output.encoder_prompt encoder_prompt = output.encoder_prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Output {i+1}:") print(f"Output {i + 1}:")
print(f"Encoder prompt: {encoder_prompt!r}\n" print(
f"Encoder prompt: {encoder_prompt!r}\n"
f"Decoder prompt: {prompt!r}\n" f"Decoder prompt: {prompt!r}\n"
f"Generated text: {generated_text!r}") f"Generated text: {generated_text!r}"
)
print("-" * 50) print("-" * 50)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
This example shows how to use vLLM for running offline inference with This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation. the explicit/implicit prompt format on enc-dec LMMs for text generation.
""" """
import time import time
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import asdict from dataclasses import asdict
...@@ -32,16 +33,12 @@ def run_florence2(): ...@@ -32,16 +33,12 @@ def run_florence2():
prompts = [ prompts = [
{ # implicit prompt with task token { # implicit prompt with task token
"prompt": "<DETAILED_CAPTION>", "prompt": "<DETAILED_CAPTION>",
"multi_modal_data": { "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
"image": ImageAsset("stop_sign").pil_image
},
}, },
{ # explicit encoder/decoder prompt { # explicit encoder/decoder prompt
"encoder_prompt": { "encoder_prompt": {
"prompt": "Describe in detail what is shown in the image.", "prompt": "Describe in detail what is shown in the image.",
"multi_modal_data": { "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
"image": ImageAsset("cherry_blossom").pil_image
},
}, },
"decoder_prompt": "", "decoder_prompt": "",
}, },
...@@ -110,7 +107,7 @@ def run_whisper(): ...@@ -110,7 +107,7 @@ def run_whisper():
}, },
}, },
"decoder_prompt": "<|startoftranscript|>", "decoder_prompt": "<|startoftranscript|>",
} },
] ]
return ModelRequestData( return ModelRequestData(
...@@ -128,18 +125,23 @@ model_example_map = { ...@@ -128,18 +125,23 @@ model_example_map = {
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description="Demo on using vLLM for offline inference with "
'vision language models for text generation') "vision language models for text generation"
parser.add_argument('--model-type', )
'-m', parser.add_argument(
"--model-type",
"-m",
type=str, type=str,
default="mllama", default="mllama",
choices=model_example_map.keys(), choices=model_example_map.keys(),
help='Huggingface "model_type".') help='Huggingface "model_type".',
parser.add_argument("--seed", )
parser.add_argument(
"--seed",
type=int, type=int,
default=None, default=None,
help="Set the seed when initializing `vllm.LLM`.") help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args() return parser.parse_args()
...@@ -153,7 +155,8 @@ def main(args): ...@@ -153,7 +155,8 @@ def main(args):
# Disable other modalities to save memory # Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0} default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict( req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {}) req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed} engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args) llm = LLM(**engine_args)
...@@ -179,8 +182,7 @@ def main(args): ...@@ -179,8 +182,7 @@ def main(args):
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Decoder prompt: {prompt!r}, " print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")
f"Generated text: {generated_text!r}")
duration = time.time() - start duration = time.time() - start
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
This file demonstrates using the `LLMEngine` This file demonstrates using the `LLMEngine`
for processing prompts with various sampling parameters. for processing prompts with various sampling parameters.
""" """
import argparse import argparse
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
...@@ -12,24 +13,26 @@ from vllm.utils import FlexibleArgumentParser ...@@ -12,24 +13,26 @@ from vllm.utils import FlexibleArgumentParser
def create_test_prompts() -> list[tuple[str, SamplingParams]]: def create_test_prompts() -> list[tuple[str, SamplingParams]]:
"""Create a list of test prompts with their sampling parameters.""" """Create a list of test prompts with their sampling parameters."""
return [ return [
("A robot may not injure a human being", (
SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)), "A robot may not injure a human being",
("To be or not to be,", SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1),
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)), ),
("What is the meaning of life?", (
SamplingParams(n=2, "To be or not to be,",
temperature=0.8, SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2),
top_p=0.95, ),
frequency_penalty=0.1)), (
"What is the meaning of life?",
SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1),
),
] ]
def process_requests(engine: LLMEngine, def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]):
test_prompts: list[tuple[str, SamplingParams]]):
"""Continuously process a list of prompts and handle the outputs.""" """Continuously process a list of prompts and handle the outputs."""
request_id = 0 request_id = 0
print('-' * 50) print("-" * 50)
while test_prompts or engine.has_unfinished_requests(): while test_prompts or engine.has_unfinished_requests():
if test_prompts: if test_prompts:
prompt, sampling_params = test_prompts.pop(0) prompt, sampling_params = test_prompts.pop(0)
...@@ -41,7 +44,7 @@ def process_requests(engine: LLMEngine, ...@@ -41,7 +44,7 @@ def process_requests(engine: LLMEngine,
for request_output in request_outputs: for request_output in request_outputs:
if request_output.finished: if request_output.finished:
print(request_output) print(request_output)
print('-' * 50) print("-" * 50)
def initialize_engine(args: argparse.Namespace) -> LLMEngine: def initialize_engine(args: argparse.Namespace) -> LLMEngine:
...@@ -52,7 +55,8 @@ def initialize_engine(args: argparse.Namespace) -> LLMEngine: ...@@ -52,7 +55,8 @@ def initialize_engine(args: argparse.Namespace) -> LLMEngine:
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using the LLMEngine class directly') description="Demo on using the LLMEngine class directly"
)
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
return parser.parse_args() return parser.parse_args()
...@@ -64,6 +68,6 @@ def main(args: argparse.Namespace): ...@@ -64,6 +68,6 @@ def main(args: argparse.Namespace):
process_requests(engine, test_prompts) process_requests(engine, test_prompts)
if __name__ == '__main__': if __name__ == "__main__":
args = parse_args() args = parse_args()
main(args) main(args)
...@@ -36,22 +36,21 @@ def parse_args(): ...@@ -36,22 +36,21 @@ def parse_args():
parser.set_defaults(load_format="sharded_state") parser.set_defaults(load_format="sharded_state")
# Add validation arguments # Add validation arguments
parser.add_argument("--prompt", parser.add_argument(
type=str, "--prompt", type=str, default="Hello, world!", help="Prompt for validation"
default="Hello, world!", )
help="Prompt for validation") parser.add_argument(
parser.add_argument("--max-tokens", "--max-tokens",
type=int, type=int,
default=100, default=100,
help="Maximum number of tokens to generate") help="Maximum number of tokens to generate",
parser.add_argument("--temperature", )
type=float, parser.add_argument(
default=0.7, "--temperature", type=float, default=0.7, help="Sampling temperature"
help="Sampling temperature") )
parser.add_argument("--top-p", parser.add_argument(
type=float, "--top-p", type=float, default=1.0, help="Top-p sampling parameter"
default=1.0, )
help="Top-p sampling parameter")
return parser.parse_args() return parser.parse_args()
...@@ -60,8 +59,9 @@ def main(): ...@@ -60,8 +59,9 @@ def main():
args = parse_args() args = parse_args()
engine_args = EngineArgs.from_cli_args(args) engine_args = EngineArgs.from_cli_args(args)
print(f"Loading model from {engine_args.model} " print(
f"using format {engine_args.load_format}") f"Loading model from {engine_args.model} using format {engine_args.load_format}"
)
print(f"Tensor parallel size: {engine_args.tensor_parallel_size}") print(f"Tensor parallel size: {engine_args.tensor_parallel_size}")
# Load the model using engine args # Load the model using engine args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment