Unverified commit 27bebcd8, authored by Harry Mellor and committed by GitHub
Browse files

Convert `examples` to `ruff-format` (#18400)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent e7523c2e
......@@ -45,12 +45,12 @@ def main():
# Round 2
messages.append({"role": "assistant", "content": content})
messages.append({
"role":
"user",
"content":
"How many Rs are there in the word 'strawberry'?",
})
messages.append(
{
"role": "user",
"content": "How many Rs are there in the word 'strawberry'?",
}
)
response = client.chat.completions.create(model=model, messages=messages)
reasoning_content = response.choices[0].message.reasoning_content
......
......@@ -43,9 +43,7 @@ def main():
# ruff: noqa: E501
# For granite: add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
stream = client.chat.completions.create(model=model,
messages=messages,
stream=True)
stream = client.chat.completions.create(model=model, messages=messages, stream=True)
print("client: Start streaming chat completions...")
printed_reasoning_content = False
......
......@@ -14,26 +14,17 @@ def vlm2vec():
response = requests.post(
"http://localhost:8000/v1/embeddings",
json={
"model":
"TIGER-Lab/VLM2Vec-Full",
"messages": [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "Represent the given image."
},
],
}],
"encoding_format":
"float",
"model": "TIGER-Lab/VLM2Vec-Full",
"messages": [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": "Represent the given image."},
],
}
],
"encoding_format": "float",
},
)
response.raise_for_status()
......@@ -45,19 +36,20 @@ def vlm2vec():
def dse_qwen2_vl(inp: dict):
# Embedding an Image
if inp["type"] == "image":
messages = [{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": inp["image_url"],
}
}, {
"type": "text",
"text": "What is shown in this image?"
}]
}]
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": inp["image_url"],
},
},
{"type": "text", "text": "What is shown in this image?"},
],
}
]
# Embedding a Text Query
else:
# MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
......@@ -66,23 +58,21 @@ def dse_qwen2_vl(inp: dict):
image_placeholder = Image.new("RGB", (56, 56))
image_placeholder.save(buffer, "png")
buffer.seek(0)
image_placeholder = base64.b64encode(buffer.read()).decode('utf-8')
messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_placeholder}",
}
},
{
"type": "text",
"text": f"Query: {inp['content']}"
},
]
}]
image_placeholder = base64.b64encode(buffer.read()).decode("utf-8")
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_placeholder}",
},
},
{"type": "text", "text": f"Query: {inp['content']}"},
],
}
]
response = requests.post(
"http://localhost:8000/v1/embeddings",
......@@ -101,12 +91,15 @@ def dse_qwen2_vl(inp: dict):
def parse_args():
parser = argparse.ArgumentParser(
"Script to call a specified VLM through the API. Make sure to serve "
"the model with --task embed before running this.")
parser.add_argument("--model",
type=str,
choices=["vlm2vec", "dse_qwen2_vl"],
required=True,
help="Which model to call.")
"the model with --task embed before running this."
)
parser.add_argument(
"--model",
type=str,
choices=["vlm2vec", "dse_qwen2_vl"],
required=True,
help="Which model to call.",
)
return parser.parse_args()
......@@ -114,16 +107,20 @@ def main(args):
if args.model == "vlm2vec":
vlm2vec()
elif args.model == "dse_qwen2_vl":
dse_qwen2_vl({
"type": "image",
"image_url": image_url,
})
dse_qwen2_vl({
"type": "text",
"content": "What is the weather like today?",
})
dse_qwen2_vl(
{
"type": "image",
"image_url": image_url,
}
)
dse_qwen2_vl(
{
"type": "text",
"content": "What is the weather like today?",
}
)
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_args()
main(args)
......@@ -16,9 +16,7 @@ def parse_args():
parse = argparse.ArgumentParser()
parse.add_argument("--host", type=str, default="localhost")
parse.add_argument("--port", type=int, default=8000)
parse.add_argument("--model",
type=str,
default="jason9693/Qwen2.5-1.5B-apeach")
parse.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
return parse.parse_args()
......
......@@ -11,9 +11,9 @@ openai_api_base = "http://localhost:8000/v1"
def parse_args():
parser = argparse.ArgumentParser(description="Client for vLLM API server")
parser.add_argument("--stream",
action="store_true",
help="Enable streaming response")
parser.add_argument(
"--stream", action="store_true", help="Enable streaming response"
)
return parser.parse_args()
......@@ -34,7 +34,8 @@ def main(args):
echo=False,
n=2,
stream=args.stream,
logprobs=3)
logprobs=3,
)
print("-" * 50)
print("Completion results:")
......
......@@ -4,6 +4,7 @@ Example online usage of Score API.
Run `vllm serve <model> --task score` to start up the server in vLLM.
"""
import argparse
import pprint
......@@ -38,9 +39,7 @@ def main(args):
pprint.pprint(score_response.json())
text_1 = "What is the capital of France?"
text_2 = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url)
print("\nPrompt when text_1 is string and text_2 is a list:")
......@@ -48,12 +47,8 @@ def main(args):
print("\nScore Response:")
pprint.pprint(score_response.json())
text_1 = [
"What is the capital of Brazil?", "What is the capital of France?"
]
text_2 = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
text_1 = ["What is the capital of Brazil?", "What is the capital of France?"]
text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url)
print("\nPrompt when text_1 and text_2 are both lists:")
......
......@@ -21,7 +21,7 @@ def main():
# ruff: noqa: E501
input=[
"Hello my name is",
"The best thing about vLLM is that it supports many different models"
"The best thing about vLLM is that it supports many different models",
],
model=model,
)
......
......@@ -5,6 +5,7 @@ Example online usage of Pooling API.
Run `vllm serve <model> --task <embed|classify|reward|score>`
to start up the server in vLLM.
"""
import argparse
import pprint
......@@ -21,9 +22,7 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model",
type=str,
default="jason9693/Qwen2.5-1.5B-apeach")
parser.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
return parser.parse_args()
......@@ -42,15 +41,13 @@ def main(args):
# Input like Chat API
prompt = {
"model":
model_name,
"messages": [{
"role": "user",
"content": [{
"type": "text",
"text": "vLLM is great!"
}],
}]
"model": model_name,
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "vLLM is great!"}],
}
],
}
pooling_response = post_http_request(prompt=prompt, api_url=api_url)
print("Pooling Response:")
......
......@@ -7,8 +7,8 @@ from openai import OpenAI
from vllm.assets.audio import AudioAsset
mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path()
winning_call = AudioAsset('winning_call').get_local_path()
mary_had_lamb = AudioAsset("mary_had_lamb").get_local_path()
winning_call = AudioAsset("winning_call").get_local_path()
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
......@@ -31,7 +31,8 @@ def sync_openai():
extra_body=dict(
seed=4419,
repetition_penalty=1.3,
))
),
)
print("transcription result:", transcription.text)
......@@ -42,33 +43,30 @@ sync_openai()
async def stream_openai_response():
data = {
"language": "en",
'stream': True,
"stream": True,
"model": "openai/whisper-large-v3",
}
url = openai_api_base + "/audio/transcriptions"
headers = {"Authorization": f"Bearer {openai_api_key}"}
print("transcription result:", end=' ')
print("transcription result:", end=" ")
async with httpx.AsyncClient() as client:
with open(str(winning_call), "rb") as f:
async with client.stream('POST',
url,
files={'file': f},
data=data,
headers=headers) as response:
async with client.stream(
"POST", url, files={"file": f}, data=data, headers=headers
) as response:
async for line in response.aiter_lines():
# Each line is a JSON object prefixed with 'data: '
if line:
if line.startswith('data: '):
line = line[len('data: '):]
if line.startswith("data: "):
line = line[len("data: ") :]
# Last chunk, stream ends
if line.strip() == '[DONE]':
if line.strip() == "[DONE]":
break
# Parse the JSON response
chunk = json.loads(line)
# Extract and print the content
content = chunk['choices'][0].get('delta',
{}).get('content')
print(content, end='')
content = chunk["choices"][0].get("delta", {}).get("content")
print(content, end="")
# Run the asynchronous function
......
# SPDX-License-Identifier: Apache-2.0
import requests
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (BatchSpanProcessor,
ConsoleSpanExporter)
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
from opentelemetry.trace import SpanKind, set_tracer_provider
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator)
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
trace_provider = TracerProvider()
set_tracer_provider(trace_provider)
......
......@@ -26,6 +26,7 @@ Dependencies:
- torch
- openai
"""
import base64
import io
......@@ -44,17 +45,13 @@ def main():
# Transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
transformers_model = transformers.AutoModelForCausalLM.from_pretrained(
model_name)
transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
# Refer to the HuggingFace repo for the correct format to use
chat = [{
"role": "user",
"content": "Please tell me about the capital of France."
}]
token_ids = tokenizer.apply_chat_template(chat,
add_generation_prompt=True,
return_tensors='pt')
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
token_ids = tokenizer.apply_chat_template(
chat, add_generation_prompt=True, return_tensors="pt"
)
embedding_layer = transformers_model.get_input_embeddings()
prompt_embeds = embedding_layer(token_ids).squeeze(0)
......@@ -64,7 +61,7 @@ def main():
torch.save(prompt_embeds, buffer)
buffer.seek(0)
binary_data = buffer.read()
encoded_embeds = base64.b64encode(binary_data).decode('utf-8')
encoded_embeds = base64.b64encode(binary_data).decode("utf-8")
completion = client.completions.create(
model=model_name,
......@@ -75,7 +72,8 @@ def main():
temperature=0.0,
# NOTE: The OpenAI client allows passing in extra JSON body via the
# `extra_body` argument.
extra_body={"prompt_embeds": encoded_embeds})
extra_body={"prompt_embeds": encoded_embeds},
)
print("-" * 30)
print(completion.choices[0].text)
......
......@@ -28,9 +28,7 @@ llm_config = LLMConfig(
},
# Change to the accelerator type of the node
accelerator_type="H100",
runtime_env={"env_vars": {
"VLLM_USE_V1": "1"
}},
runtime_env={"env_vars": {"VLLM_USE_V1": "1"}},
# Customize engine arguments as needed (e.g. vLLM engine kwargs)
engine_kwargs={
"tensor_parallel_size": 8,
......
......@@ -55,7 +55,7 @@ def load_and_split_documents(config: dict[str, Any]):
Load and split documents from web URL
"""
try:
loader = WebBaseLoader(web_paths=(config["url"], ))
loader = WebBaseLoader(web_paths=(config["url"],))
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
......@@ -121,64 +121,71 @@ def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate):
"""
Set up question answering chain
"""
return ({
"context": retriever | format_docs,
"question": RunnablePassthrough(),
}
| prompt
| llm
| StrOutputParser())
return (
{
"context": retriever | format_docs,
"question": RunnablePassthrough(),
}
| prompt
| llm
| StrOutputParser()
)
def get_parser() -> argparse.ArgumentParser:
"""
Parse command line arguments
"""
parser = argparse.ArgumentParser(description='RAG with vLLM and langchain')
parser = argparse.ArgumentParser(description="RAG with vLLM and langchain")
# Add command line arguments
parser.add_argument('--vllm-api-key',
default="EMPTY",
help='API key for vLLM compatible services')
parser.add_argument('--vllm-embedding-endpoint',
default="http://localhost:8000/v1",
help='Base URL for embedding service')
parser.add_argument('--vllm-chat-endpoint',
default="http://localhost:8001/v1",
help='Base URL for chat service')
parser.add_argument('--uri',
default="./milvus.db",
help='URI for Milvus database')
parser.add_argument(
'--url',
default=("https://docs.vllm.ai/en/latest/getting_started/"
"quickstart.html"),
help='URL of the document to process')
parser.add_argument('--embedding-model',
default="ssmits/Qwen2-7B-Instruct-embed-base",
help='Model name for embeddings')
parser.add_argument('--chat-model',
default="qwen/Qwen1.5-0.5B-Chat",
help='Model name for chat')
parser.add_argument('-i',
'--interactive',
action='store_true',
help='Enable interactive Q&A mode')
parser.add_argument('-k',
'--top-k',
type=int,
default=3,
help='Number of top results to retrieve')
parser.add_argument('-c',
'--chunk-size',
type=int,
default=1000,
help='Chunk size for document splitting')
parser.add_argument('-o',
'--chunk-overlap',
type=int,
default=200,
help='Chunk overlap for document splitting')
"--vllm-api-key", default="EMPTY", help="API key for vLLM compatible services"
)
parser.add_argument(
"--vllm-embedding-endpoint",
default="http://localhost:8000/v1",
help="Base URL for embedding service",
)
parser.add_argument(
"--vllm-chat-endpoint",
default="http://localhost:8001/v1",
help="Base URL for chat service",
)
parser.add_argument("--uri", default="./milvus.db", help="URI for Milvus database")
parser.add_argument(
"--url",
default=("https://docs.vllm.ai/en/latest/getting_started/quickstart.html"),
help="URL of the document to process",
)
parser.add_argument(
"--embedding-model",
default="ssmits/Qwen2-7B-Instruct-embed-base",
help="Model name for embeddings",
)
parser.add_argument(
"--chat-model", default="qwen/Qwen1.5-0.5B-Chat", help="Model name for chat"
)
parser.add_argument(
"-i", "--interactive", action="store_true", help="Enable interactive Q&A mode"
)
parser.add_argument(
"-k", "--top-k", type=int, default=3, help="Number of top results to retrieve"
)
parser.add_argument(
"-c",
"--chunk-size",
type=int,
default=1000,
help="Chunk size for document splitting",
)
parser.add_argument(
"-o",
"--chunk-overlap",
type=int,
default=200,
help="Chunk overlap for document splitting",
)
return parser
......@@ -198,7 +205,7 @@ def init_config(args: Namespace):
"url": args.url,
"chunk_size": args.chunk_size,
"chunk_overlap": args.chunk_overlap,
"top_k": args.top_k
"top_k": args.top_k,
}
......@@ -230,7 +237,7 @@ def main():
while True:
question = input("\nPlease enter your question: ")
if question.lower() in ['q', 'quit']:
if question.lower() in ["q", "quit"]:
print("\nThank you for using! Goodbye!")
break
......@@ -238,7 +245,7 @@ def main():
print(output)
else:
# Default single question mode
question = ("How to install vLLM?")
question = "How to install vLLM?"
output = qa_chain.invoke(question)
print("-" * 50)
print(output)
......
......@@ -35,6 +35,7 @@ Notes:
- Default ports: 8000 (embedding), 8001 (chat)
- First run may take time to download models
"""
import argparse
from argparse import Namespace
from typing import Any
......@@ -59,7 +60,7 @@ def init_config(args: Namespace):
"db_path": args.db_path,
"chunk_size": args.chunk_size,
"chunk_overlap": args.chunk_overlap,
"top_k": args.top_k
"top_k": args.top_k,
}
......@@ -117,52 +118,58 @@ def query_document(index: VectorStoreIndex, question: str, top_k: int):
def get_parser() -> argparse.ArgumentParser:
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description='RAG with vLLM and LlamaIndex')
parser = argparse.ArgumentParser(description="RAG with vLLM and LlamaIndex")
# Add command line arguments
parser.add_argument(
'--url',
default=("https://docs.vllm.ai/en/latest/getting_started/"
"quickstart.html"),
help='URL of the document to process')
parser.add_argument('--embedding-model',
default="ssmits/Qwen2-7B-Instruct-embed-base",
help='Model name for embeddings')
parser.add_argument('--chat-model',
default="qwen/Qwen1.5-0.5B-Chat",
help='Model name for chat')
parser.add_argument('--vllm-api-key',
default="EMPTY",
help='API key for vLLM compatible services')
parser.add_argument('--embedding-endpoint',
default="http://localhost:8000/v1",
help='Base URL for embedding service')
parser.add_argument('--chat-endpoint',
default="http://localhost:8001/v1",
help='Base URL for chat service')
parser.add_argument('--db-path',
default="./milvus_demo.db",
help='Path to Milvus database')
parser.add_argument('-i',
'--interactive',
action='store_true',
help='Enable interactive Q&A mode')
parser.add_argument('-c',
'--chunk-size',
type=int,
default=1000,
help='Chunk size for document splitting')
parser.add_argument('-o',
'--chunk-overlap',
type=int,
default=200,
help='Chunk overlap for document splitting')
parser.add_argument('-k',
'--top-k',
type=int,
default=3,
help='Number of top results to retrieve')
"--url",
default=("https://docs.vllm.ai/en/latest/getting_started/quickstart.html"),
help="URL of the document to process",
)
parser.add_argument(
"--embedding-model",
default="ssmits/Qwen2-7B-Instruct-embed-base",
help="Model name for embeddings",
)
parser.add_argument(
"--chat-model", default="qwen/Qwen1.5-0.5B-Chat", help="Model name for chat"
)
parser.add_argument(
"--vllm-api-key", default="EMPTY", help="API key for vLLM compatible services"
)
parser.add_argument(
"--embedding-endpoint",
default="http://localhost:8000/v1",
help="Base URL for embedding service",
)
parser.add_argument(
"--chat-endpoint",
default="http://localhost:8001/v1",
help="Base URL for chat service",
)
parser.add_argument(
"--db-path", default="./milvus_demo.db", help="Path to Milvus database"
)
parser.add_argument(
"-i", "--interactive", action="store_true", help="Enable interactive Q&A mode"
)
parser.add_argument(
"-c",
"--chunk-size",
type=int,
default=1000,
help="Chunk size for document splitting",
)
parser.add_argument(
"-o",
"--chunk-overlap",
type=int,
default=200,
help="Chunk overlap for document splitting",
)
parser.add_argument(
"-k", "--top-k", type=int, default=3, help="Number of top results to retrieve"
)
return parser
......@@ -193,7 +200,7 @@ def main():
question = input("\nEnter your question: ")
# Check for exit command
if question.lower() in ['quit', 'exit', 'q']:
if question.lower() in ["quit", "exit", "q"]:
print("Exiting interactive mode...")
break
......
......@@ -26,6 +26,7 @@ Usage:
streamlit run streamlit_openai_chatbot_webserver.py \
--logger.level=debug
"""
import os
from datetime import datetime
......@@ -33,8 +34,8 @@ import streamlit as st
from openai import OpenAI
# Get command line arguments from environment variables
openai_api_key = os.getenv('VLLM_API_KEY', "EMPTY")
openai_api_base = os.getenv('VLLM_API_BASE', "http://localhost:8000/v1")
openai_api_key = os.getenv("VLLM_API_KEY", "EMPTY")
openai_api_base = os.getenv("VLLM_API_BASE", "http://localhost:8000/v1")
# Initialize session states for managing chat sessions
if "sessions" not in st.session_state:
......@@ -81,9 +82,9 @@ def get_llm_response(messages, model):
Streaming response object or error message string
"""
try:
response = client.chat.completions.create(model=model,
messages=messages,
stream=True)
response = client.chat.completions.create(
model=model, messages=messages, stream=True
)
return response
except Exception as e:
st.error(f"Error details: {str(e)}")
......@@ -92,8 +93,9 @@ def get_llm_response(messages, model):
# Sidebar - API Settings first
st.sidebar.title("API Settings")
new_api_base = st.sidebar.text_input("API Base URL:",
value=st.session_state.api_base_url)
new_api_base = st.sidebar.text_input(
"API Base URL:", value=st.session_state.api_base_url
)
if new_api_base != st.session_state.api_base_url:
st.session_state.api_base_url = new_api_base
st.rerun()
......@@ -109,16 +111,20 @@ if st.sidebar.button("New Session"):
for session_id in sorted(st.session_state.sessions.keys(), reverse=True):
# Mark the active session with a pinned button
if session_id == st.session_state.active_session:
st.sidebar.button(f"📍 {session_id}",
key=session_id,
type="primary",
on_click=switch_to_chat_session,
args=(session_id, ))
st.sidebar.button(
f"📍 {session_id}",
key=session_id,
type="primary",
on_click=switch_to_chat_session,
args=(session_id,),
)
else:
st.sidebar.button(f"Session {session_id}",
key=session_id,
on_click=switch_to_chat_session,
args=(session_id, ))
st.sidebar.button(
f"Session {session_id}",
key=session_id,
on_click=switch_to_chat_session,
args=(session_id,),
)
# Main interface
st.title("vLLM Chat Assistant")
......@@ -145,18 +151,18 @@ for message in st.session_state.messages:
if prompt := st.chat_input("Type your message here..."):
# Save user message to session
st.session_state.messages.append({"role": "user", "content": prompt})
st.session_state.sessions[
st.session_state.current_session] = st.session_state.messages
st.session_state.sessions[st.session_state.current_session] = (
st.session_state.messages
)
# Display user message
with st.chat_message("user"):
st.write(prompt)
# Prepare messages for llm
messages_for_llm = [{
"role": m["role"],
"content": m["content"]
} for m in st.session_state.messages]
messages_for_llm = [
{"role": m["role"], "content": m["content"]} for m in st.session_state.messages
]
# Generate and display llm response
with st.chat_message("assistant"):
......@@ -179,7 +185,4 @@ if prompt := st.chat_input("Type your message here..."):
message_placeholder.markdown(full_response)
# Save llm response to session history
st.session_state.messages.append({
"role": "assistant",
"content": full_response
})
st.session_state.messages.append({"role": "assistant", "content": full_response})
......@@ -16,10 +16,10 @@ def get_first_model(client: OpenAI) -> str:
f"{client.base_url} with API key {client.api_key}. Check\n"
"1. the server is running\n"
"2. the server URL is correct\n"
"3. the API key is correct") from e
"3. the API key is correct"
) from e
if len(models.data) == 0:
raise RuntimeError(
f"No models found on the vLLM server at {client.base_url}")
raise RuntimeError(f"No models found on the vLLM server at {client.base_url}")
return models.data[0].id
......@@ -20,6 +20,7 @@ Requirements: Linux, Python: 3.10 or higher, CUDA: 12.1
Learn more about LMCache environment setup, please refer to:
https://docs.lmcache.ai/getting_started/installation.html
"""
import argparse
import contextlib
import os
......@@ -49,8 +50,7 @@ def setup_environment_variables(vllm_version: str):
@contextlib.contextmanager
def build_llm_with_lmcache(lmcache_connector: str, model: str,
vllm_version: str):
def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str):
ktc = KVTransferConfig(
kv_connector=lmcache_connector,
kv_role="kv_both",
......@@ -97,18 +97,19 @@ def print_output(
for output in outputs:
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
print(f"Generation took {time.time() - start:.2f} seconds, "
f"{req_str} request done.")
print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
print("-" * 50)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-v",
"--version",
choices=["v0", "v1"],
default="v1",
help="Specify vLLM version (default: v1)")
parser.add_argument(
"-v",
"--version",
choices=["v0", "v1"],
default="v1",
help="Specify vLLM version (default: v1)",
)
return parser.parse_args()
......@@ -125,7 +126,6 @@ def main():
setup_environment_variables(args.version)
with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm:
# This example script runs two requests with a shared prefix.
# Define the shared prompt and specific prompts
shared_prompt = "Hello, how are you?" * 1000
......@@ -136,9 +136,7 @@ def main():
shared_prompt + "Tell me a very long story",
]
sampling_params = SamplingParams(temperature=0,
top_p=0.95,
max_tokens=10)
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
# Print the first output
print_output(llm, first_prompt, sampling_params, "first")
......
......@@ -4,12 +4,13 @@ This file demonstrates the example usage of disaggregated prefilling
with LMCache.
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
and launch an additional LMCache server.
KV cache is transferred in the following manner:
KV cache is transferred in the following manner:
vLLM prefill node -> LMCache server -> vLLM decode node.
Note that `pip install lmcache` is needed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
"""
import os
import subprocess
import time
......@@ -49,19 +50,23 @@ def run_prefill(prefill_done, prompts):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
ktc = KVTransferConfig(kv_connector="LMCacheConnector",
kv_role="kv_producer",
kv_rank=0,
kv_parallel_size=2)
ktc = KVTransferConfig(
kv_connector="LMCacheConnector",
kv_role="kv_producer",
kv_rank=0,
kv_parallel_size=2,
)
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True)
#llm.generate(prompts, sampling_params)
llm = LLM(
model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True,
)
# llm.generate(prompts, sampling_params)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
generated_text = output.outputs[0].text
......@@ -79,17 +84,21 @@ def run_decode(prefill_done, prompts, timeout=1):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig(kv_connector="LMCacheConnector",
kv_role="kv_consumer",
kv_rank=1,
kv_parallel_size=2)
ktc = KVTransferConfig(
kv_connector="LMCacheConnector",
kv_role="kv_consumer",
kv_rank=1,
kv_parallel_size=2,
)
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True)
llm = LLM(
model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True,
)
print("Waiting for prefill node to finish...")
prefill_done.wait()
......@@ -105,10 +114,9 @@ def run_decode(prefill_done, prompts, timeout=1):
def run_lmcache_server(port):
server_proc = subprocess.Popen([
"python", "-m", "lmcache.experimental.server", "localhost",
str(port)
])
server_proc = subprocess.Popen(
["python", "-m", "lmcache.experimental.server", "localhost", str(port)]
)
return server_proc
......
......@@ -17,13 +17,17 @@ async def lifespan(app: FastAPI):
Lifespan context manager to handle startup and shutdown events.
"""
# Startup: Initialize clients
prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1'
decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1'
app.state.prefill_client = httpx.AsyncClient(timeout=None,
base_url=prefiller_base_url)
app.state.decode_client = httpx.AsyncClient(timeout=None,
base_url=decoder_base_url)
prefiller_base_url = (
f"http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1"
)
decoder_base_url = (
f"http://{global_args.decoder_host}:{global_args.decoder_port}/v1"
)
app.state.prefill_client = httpx.AsyncClient(
timeout=None, base_url=prefiller_base_url
)
app.state.decode_client = httpx.AsyncClient(timeout=None, base_url=decoder_base_url)
yield
......@@ -37,7 +41,6 @@ app = FastAPI(lifespan=lifespan)
class StatsCalculator:
def __init__(self):
self._stats = []
self._last_log_time = time.time()
......@@ -51,13 +54,18 @@ class StatsCalculator:
def _log_stats(self):
# Print average, median, and 99th percentile
np_arr = np.array(self._stats)
output_str = f"\nNum requests: {len(self._stats)}" + \
"\nPrefill node TTFT stats:" + \
f"\n - Average (ms): {np.mean(np_arr)}" + \
f"\n - Median (ms): {np.median(np_arr)}" + \
f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
print("===============================", output_str,
"===============================")
output_str = (
f"\nNum requests: {len(self._stats)}"
+ "\nPrefill node TTFT stats:"
+ f"\n - Average (ms): {np.mean(np_arr)}"
+ f"\n - Median (ms): {np.median(np_arr)}"
+ f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
)
print(
"===============================",
output_str,
"===============================",
)
stats_calculator = StatsCalculator()
......@@ -82,15 +90,16 @@ app.state.prefill_client = None
app.state.decode_client = None
async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
req_data: dict):
async def send_request_to_service(
client: httpx.AsyncClient, endpoint: str, req_data: dict
):
"""
Send a request to a service using a persistent client.
"""
req_data = req_data.copy()
req_data['max_tokens'] = 1
if 'max_completion_tokens' in req_data:
req_data['max_completion_tokens'] = 1
req_data["max_tokens"] = 1
if "max_completion_tokens" in req_data:
req_data["max_completion_tokens"] = 1
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
response = await client.post(endpoint, json=req_data, headers=headers)
......@@ -98,14 +107,16 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
return response
async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
req_data: dict):
async def stream_service_response(
client: httpx.AsyncClient, endpoint: str, req_data: dict
):
"""
Asynchronously stream the response from a service using a persistent client.
"""
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
async with client.stream("POST", endpoint, json=req_data,
headers=headers) as response:
async with client.stream(
"POST", endpoint, json=req_data, headers=headers
) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
yield chunk
......@@ -121,28 +132,28 @@ async def handle_completions(request: Request):
req_data = await request.json()
# Send request to prefill service, ignore the response
await send_request_to_service(app.state.prefill_client, "/completions",
req_data)
await send_request_to_service(
app.state.prefill_client, "/completions", req_data
)
et = time.time()
stats_calculator.add(et - st)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(app.state.decode_client,
"/completions",
req_data):
async for chunk in stream_service_response(
app.state.decode_client, "/completions", req_data
):
yield chunk
return StreamingResponse(generate_stream(),
media_type="text/event-stream")
return StreamingResponse(generate_stream(), media_type="text/event-stream")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
" - completions endpoint")
print("Error occurred in disagg prefill proxy server - completions endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
......@@ -158,36 +169,39 @@ async def handle_chat_completions(request: Request):
req_data = await request.json()
# Send request to prefill service, ignore the response
await send_request_to_service(app.state.prefill_client,
"/chat/completions", req_data)
await send_request_to_service(
app.state.prefill_client, "/chat/completions", req_data
)
et = time.time()
stats_calculator.add(et - st)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(app.state.decode_client,
"/chat/completions",
req_data):
async for chunk in stream_service_response(
app.state.decode_client, "/chat/completions", req_data
):
yield chunk
return StreamingResponse(generate_stream(),
media_type="text/event-stream")
return StreamingResponse(generate_stream(), media_type="text/event-stream")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server "
" - chat completions endpoint")
print(
"Error occurred in disagg prefill proxy server - chat completions endpoint"
)
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
if __name__ == '__main__':
if __name__ == "__main__":
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)
......@@ -3,13 +3,14 @@
This file demonstrates the example usage of remote KV cache sharing
with LMCache.
We will launch 2 vllm instances, and launch an additional LMCache server.
KV cache is transferred in the following manner:
KV cache is transferred in the following manner:
(1) vLLM instance 1 -> LMCache server (KV cache store).
(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve).
Note that lmcache needs to be installed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
"""
import os
import subprocess
import time
......@@ -49,15 +50,16 @@ def run_store(store_done, prompts):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1",
kv_role="kv_both")
ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True)
llm = LLM(
model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True,
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
......@@ -76,15 +78,16 @@ def run_retrieve(store_done, prompts, timeout=1):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1",
kv_role="kv_both")
ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True)
llm = LLM(
model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True,
)
print("Waiting for KV cache store to finish...")
store_done.wait()
......@@ -100,10 +103,9 @@ def run_retrieve(store_done, prompts, timeout=1):
def run_lmcache_server(port):
server_proc = subprocess.Popen([
"python", "-m", "lmcache.experimental.server", "localhost",
str(port)
])
server_proc = subprocess.Popen(
["python", "-m", "lmcache.experimental.server", "localhost", str(port)]
)
return server_proc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment