Unverified Commit 27bebcd8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Convert `examples` to `ruff-format` (#18400)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent e7523c2e
...@@ -45,12 +45,12 @@ def main(): ...@@ -45,12 +45,12 @@ def main():
# Round 2 # Round 2
messages.append({"role": "assistant", "content": content}) messages.append({"role": "assistant", "content": content})
messages.append({ messages.append(
"role": {
"user", "role": "user",
"content": "content": "How many Rs are there in the word 'strawberry'?",
"How many Rs are there in the word 'strawberry'?", }
}) )
response = client.chat.completions.create(model=model, messages=messages) response = client.chat.completions.create(model=model, messages=messages)
reasoning_content = response.choices[0].message.reasoning_content reasoning_content = response.choices[0].message.reasoning_content
......
...@@ -43,9 +43,7 @@ def main(): ...@@ -43,9 +43,7 @@ def main():
# ruff: noqa: E501 # ruff: noqa: E501
# For granite: add: `extra_body={"chat_template_kwargs": {"thinking": True}}` # For granite: add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
stream = client.chat.completions.create(model=model, stream = client.chat.completions.create(model=model, messages=messages, stream=True)
messages=messages,
stream=True)
print("client: Start streaming chat completions...") print("client: Start streaming chat completions...")
printed_reasoning_content = False printed_reasoning_content = False
......
...@@ -14,26 +14,17 @@ def vlm2vec(): ...@@ -14,26 +14,17 @@ def vlm2vec():
response = requests.post( response = requests.post(
"http://localhost:8000/v1/embeddings", "http://localhost:8000/v1/embeddings",
json={ json={
"model": "model": "TIGER-Lab/VLM2Vec-Full",
"TIGER-Lab/VLM2Vec-Full", "messages": [
"messages": [{ {
"role": "role": "user",
"user", "content": [
"content": [ {"type": "image_url", "image_url": {"url": image_url}},
{ {"type": "text", "text": "Represent the given image."},
"type": "image_url", ],
"image_url": { }
"url": image_url ],
} "encoding_format": "float",
},
{
"type": "text",
"text": "Represent the given image."
},
],
}],
"encoding_format":
"float",
}, },
) )
response.raise_for_status() response.raise_for_status()
...@@ -45,19 +36,20 @@ def vlm2vec(): ...@@ -45,19 +36,20 @@ def vlm2vec():
def dse_qwen2_vl(inp: dict): def dse_qwen2_vl(inp: dict):
# Embedding an Image # Embedding an Image
if inp["type"] == "image": if inp["type"] == "image":
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [{ "content": [
"type": "image_url", {
"image_url": { "type": "image_url",
"url": inp["image_url"], "image_url": {
} "url": inp["image_url"],
}, { },
"type": "text", },
"text": "What is shown in this image?" {"type": "text", "text": "What is shown in this image?"},
}] ],
}] }
]
# Embedding a Text Query # Embedding a Text Query
else: else:
# MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
...@@ -66,23 +58,21 @@ def dse_qwen2_vl(inp: dict): ...@@ -66,23 +58,21 @@ def dse_qwen2_vl(inp: dict):
image_placeholder = Image.new("RGB", (56, 56)) image_placeholder = Image.new("RGB", (56, 56))
image_placeholder.save(buffer, "png") image_placeholder.save(buffer, "png")
buffer.seek(0) buffer.seek(0)
image_placeholder = base64.b64encode(buffer.read()).decode('utf-8') image_placeholder = base64.b64encode(buffer.read()).decode("utf-8")
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
"url": f"data:image/jpeg;base64,{image_placeholder}", "url": f"data:image/jpeg;base64,{image_placeholder}",
} },
}, },
{ {"type": "text", "text": f"Query: {inp['content']}"},
"type": "text", ],
"text": f"Query: {inp['content']}" }
}, ]
]
}]
response = requests.post( response = requests.post(
"http://localhost:8000/v1/embeddings", "http://localhost:8000/v1/embeddings",
...@@ -101,12 +91,15 @@ def dse_qwen2_vl(inp: dict): ...@@ -101,12 +91,15 @@ def dse_qwen2_vl(inp: dict):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
"Script to call a specified VLM through the API. Make sure to serve " "Script to call a specified VLM through the API. Make sure to serve "
"the model with --task embed before running this.") "the model with --task embed before running this."
parser.add_argument("--model", )
type=str, parser.add_argument(
choices=["vlm2vec", "dse_qwen2_vl"], "--model",
required=True, type=str,
help="Which model to call.") choices=["vlm2vec", "dse_qwen2_vl"],
required=True,
help="Which model to call.",
)
return parser.parse_args() return parser.parse_args()
...@@ -114,16 +107,20 @@ def main(args): ...@@ -114,16 +107,20 @@ def main(args):
if args.model == "vlm2vec": if args.model == "vlm2vec":
vlm2vec() vlm2vec()
elif args.model == "dse_qwen2_vl": elif args.model == "dse_qwen2_vl":
dse_qwen2_vl({ dse_qwen2_vl(
"type": "image", {
"image_url": image_url, "type": "image",
}) "image_url": image_url,
dse_qwen2_vl({ }
"type": "text", )
"content": "What is the weather like today?", dse_qwen2_vl(
}) {
"type": "text",
"content": "What is the weather like today?",
}
)
if __name__ == '__main__': if __name__ == "__main__":
args = parse_args() args = parse_args()
main(args) main(args)
...@@ -16,9 +16,7 @@ def parse_args(): ...@@ -16,9 +16,7 @@ def parse_args():
parse = argparse.ArgumentParser() parse = argparse.ArgumentParser()
parse.add_argument("--host", type=str, default="localhost") parse.add_argument("--host", type=str, default="localhost")
parse.add_argument("--port", type=int, default=8000) parse.add_argument("--port", type=int, default=8000)
parse.add_argument("--model", parse.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
type=str,
default="jason9693/Qwen2.5-1.5B-apeach")
return parse.parse_args() return parse.parse_args()
......
...@@ -11,9 +11,9 @@ openai_api_base = "http://localhost:8000/v1" ...@@ -11,9 +11,9 @@ openai_api_base = "http://localhost:8000/v1"
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Client for vLLM API server") parser = argparse.ArgumentParser(description="Client for vLLM API server")
parser.add_argument("--stream", parser.add_argument(
action="store_true", "--stream", action="store_true", help="Enable streaming response"
help="Enable streaming response") )
return parser.parse_args() return parser.parse_args()
...@@ -34,7 +34,8 @@ def main(args): ...@@ -34,7 +34,8 @@ def main(args):
echo=False, echo=False,
n=2, n=2,
stream=args.stream, stream=args.stream,
logprobs=3) logprobs=3,
)
print("-" * 50) print("-" * 50)
print("Completion results:") print("Completion results:")
......
...@@ -4,6 +4,7 @@ Example online usage of Score API. ...@@ -4,6 +4,7 @@ Example online usage of Score API.
Run `vllm serve <model> --task score` to start up the server in vLLM. Run `vllm serve <model> --task score` to start up the server in vLLM.
""" """
import argparse import argparse
import pprint import pprint
...@@ -38,9 +39,7 @@ def main(args): ...@@ -38,9 +39,7 @@ def main(args):
pprint.pprint(score_response.json()) pprint.pprint(score_response.json())
text_1 = "What is the capital of France?" text_1 = "What is the capital of France?"
text_2 = [ text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url) score_response = post_http_request(prompt=prompt, api_url=api_url)
print("\nPrompt when text_1 is string and text_2 is a list:") print("\nPrompt when text_1 is string and text_2 is a list:")
...@@ -48,12 +47,8 @@ def main(args): ...@@ -48,12 +47,8 @@ def main(args):
print("\nScore Response:") print("\nScore Response:")
pprint.pprint(score_response.json()) pprint.pprint(score_response.json())
text_1 = [ text_1 = ["What is the capital of Brazil?", "What is the capital of France?"]
"What is the capital of Brazil?", "What is the capital of France?" text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
]
text_2 = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url) score_response = post_http_request(prompt=prompt, api_url=api_url)
print("\nPrompt when text_1 and text_2 are both lists:") print("\nPrompt when text_1 and text_2 are both lists:")
......
...@@ -21,7 +21,7 @@ def main(): ...@@ -21,7 +21,7 @@ def main():
# ruff: noqa: E501 # ruff: noqa: E501
input=[ input=[
"Hello my name is", "Hello my name is",
"The best thing about vLLM is that it supports many different models" "The best thing about vLLM is that it supports many different models",
], ],
model=model, model=model,
) )
......
...@@ -5,6 +5,7 @@ Example online usage of Pooling API. ...@@ -5,6 +5,7 @@ Example online usage of Pooling API.
Run `vllm serve <model> --task <embed|classify|reward|score>` Run `vllm serve <model> --task <embed|classify|reward|score>`
to start up the server in vLLM. to start up the server in vLLM.
""" """
import argparse import argparse
import pprint import pprint
...@@ -21,9 +22,7 @@ def parse_args(): ...@@ -21,9 +22,7 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000) parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", parser.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
type=str,
default="jason9693/Qwen2.5-1.5B-apeach")
return parser.parse_args() return parser.parse_args()
...@@ -42,15 +41,13 @@ def main(args): ...@@ -42,15 +41,13 @@ def main(args):
# Input like Chat API # Input like Chat API
prompt = { prompt = {
"model": "model": model_name,
model_name, "messages": [
"messages": [{ {
"role": "user", "role": "user",
"content": [{ "content": [{"type": "text", "text": "vLLM is great!"}],
"type": "text", }
"text": "vLLM is great!" ],
}],
}]
} }
pooling_response = post_http_request(prompt=prompt, api_url=api_url) pooling_response = post_http_request(prompt=prompt, api_url=api_url)
print("Pooling Response:") print("Pooling Response:")
......
...@@ -7,8 +7,8 @@ from openai import OpenAI ...@@ -7,8 +7,8 @@ from openai import OpenAI
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path() mary_had_lamb = AudioAsset("mary_had_lamb").get_local_path()
winning_call = AudioAsset('winning_call').get_local_path() winning_call = AudioAsset("winning_call").get_local_path()
# Modify OpenAI's API key and API base to use vLLM's API server. # Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
...@@ -31,7 +31,8 @@ def sync_openai(): ...@@ -31,7 +31,8 @@ def sync_openai():
extra_body=dict( extra_body=dict(
seed=4419, seed=4419,
repetition_penalty=1.3, repetition_penalty=1.3,
)) ),
)
print("transcription result:", transcription.text) print("transcription result:", transcription.text)
...@@ -42,33 +43,30 @@ sync_openai() ...@@ -42,33 +43,30 @@ sync_openai()
async def stream_openai_response(): async def stream_openai_response():
data = { data = {
"language": "en", "language": "en",
'stream': True, "stream": True,
"model": "openai/whisper-large-v3", "model": "openai/whisper-large-v3",
} }
url = openai_api_base + "/audio/transcriptions" url = openai_api_base + "/audio/transcriptions"
headers = {"Authorization": f"Bearer {openai_api_key}"} headers = {"Authorization": f"Bearer {openai_api_key}"}
print("transcription result:", end=' ') print("transcription result:", end=" ")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
with open(str(winning_call), "rb") as f: with open(str(winning_call), "rb") as f:
async with client.stream('POST', async with client.stream(
url, "POST", url, files={"file": f}, data=data, headers=headers
files={'file': f}, ) as response:
data=data,
headers=headers) as response:
async for line in response.aiter_lines(): async for line in response.aiter_lines():
# Each line is a JSON object prefixed with 'data: ' # Each line is a JSON object prefixed with 'data: '
if line: if line:
if line.startswith('data: '): if line.startswith("data: "):
line = line[len('data: '):] line = line[len("data: ") :]
# Last chunk, stream ends # Last chunk, stream ends
if line.strip() == '[DONE]': if line.strip() == "[DONE]":
break break
# Parse the JSON response # Parse the JSON response
chunk = json.loads(line) chunk = json.loads(line)
# Extract and print the content # Extract and print the content
content = chunk['choices'][0].get('delta', content = chunk["choices"][0].get("delta", {}).get("content")
{}).get('content') print(content, end="")
print(content, end='')
# Run the asynchronous function # Run the asynchronous function
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import requests import requests
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
OTLPSpanExporter)
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (BatchSpanProcessor, from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
ConsoleSpanExporter)
from opentelemetry.trace import SpanKind, set_tracer_provider from opentelemetry.trace import SpanKind, set_tracer_provider
from opentelemetry.trace.propagation.tracecontext import ( from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
TraceContextTextMapPropagator)
trace_provider = TracerProvider() trace_provider = TracerProvider()
set_tracer_provider(trace_provider) set_tracer_provider(trace_provider)
......
...@@ -26,6 +26,7 @@ Dependencies: ...@@ -26,6 +26,7 @@ Dependencies:
- torch - torch
- openai - openai
""" """
import base64 import base64
import io import io
...@@ -44,17 +45,13 @@ def main(): ...@@ -44,17 +45,13 @@ def main():
# Transformers # Transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
transformers_model = transformers.AutoModelForCausalLM.from_pretrained( transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
model_name)
# Refer to the HuggingFace repo for the correct format to use # Refer to the HuggingFace repo for the correct format to use
chat = [{ chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
"role": "user", token_ids = tokenizer.apply_chat_template(
"content": "Please tell me about the capital of France." chat, add_generation_prompt=True, return_tensors="pt"
}] )
token_ids = tokenizer.apply_chat_template(chat,
add_generation_prompt=True,
return_tensors='pt')
embedding_layer = transformers_model.get_input_embeddings() embedding_layer = transformers_model.get_input_embeddings()
prompt_embeds = embedding_layer(token_ids).squeeze(0) prompt_embeds = embedding_layer(token_ids).squeeze(0)
...@@ -64,7 +61,7 @@ def main(): ...@@ -64,7 +61,7 @@ def main():
torch.save(prompt_embeds, buffer) torch.save(prompt_embeds, buffer)
buffer.seek(0) buffer.seek(0)
binary_data = buffer.read() binary_data = buffer.read()
encoded_embeds = base64.b64encode(binary_data).decode('utf-8') encoded_embeds = base64.b64encode(binary_data).decode("utf-8")
completion = client.completions.create( completion = client.completions.create(
model=model_name, model=model_name,
...@@ -75,7 +72,8 @@ def main(): ...@@ -75,7 +72,8 @@ def main():
temperature=0.0, temperature=0.0,
# NOTE: The OpenAI client allows passing in extra JSON body via the # NOTE: The OpenAI client allows passing in extra JSON body via the
# `extra_body` argument. # `extra_body` argument.
extra_body={"prompt_embeds": encoded_embeds}) extra_body={"prompt_embeds": encoded_embeds},
)
print("-" * 30) print("-" * 30)
print(completion.choices[0].text) print(completion.choices[0].text)
......
...@@ -28,9 +28,7 @@ llm_config = LLMConfig( ...@@ -28,9 +28,7 @@ llm_config = LLMConfig(
}, },
# Change to the accelerator type of the node # Change to the accelerator type of the node
accelerator_type="H100", accelerator_type="H100",
runtime_env={"env_vars": { runtime_env={"env_vars": {"VLLM_USE_V1": "1"}},
"VLLM_USE_V1": "1"
}},
# Customize engine arguments as needed (e.g. vLLM engine kwargs) # Customize engine arguments as needed (e.g. vLLM engine kwargs)
engine_kwargs={ engine_kwargs={
"tensor_parallel_size": 8, "tensor_parallel_size": 8,
......
...@@ -55,7 +55,7 @@ def load_and_split_documents(config: dict[str, Any]): ...@@ -55,7 +55,7 @@ def load_and_split_documents(config: dict[str, Any]):
Load and split documents from web URL Load and split documents from web URL
""" """
try: try:
loader = WebBaseLoader(web_paths=(config["url"], )) loader = WebBaseLoader(web_paths=(config["url"],))
docs = loader.load() docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
...@@ -121,64 +121,71 @@ def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate): ...@@ -121,64 +121,71 @@ def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate):
""" """
Set up question answering chain Set up question answering chain
""" """
return ({ return (
"context": retriever | format_docs, {
"question": RunnablePassthrough(), "context": retriever | format_docs,
} "question": RunnablePassthrough(),
| prompt }
| llm | prompt
| StrOutputParser()) | llm
| StrOutputParser()
)
def get_parser() -> argparse.ArgumentParser: def get_parser() -> argparse.ArgumentParser:
""" """
Parse command line arguments Parse command line arguments
""" """
parser = argparse.ArgumentParser(description='RAG with vLLM and langchain') parser = argparse.ArgumentParser(description="RAG with vLLM and langchain")
# Add command line arguments # Add command line arguments
parser.add_argument('--vllm-api-key',
default="EMPTY",
help='API key for vLLM compatible services')
parser.add_argument('--vllm-embedding-endpoint',
default="http://localhost:8000/v1",
help='Base URL for embedding service')
parser.add_argument('--vllm-chat-endpoint',
default="http://localhost:8001/v1",
help='Base URL for chat service')
parser.add_argument('--uri',
default="./milvus.db",
help='URI for Milvus database')
parser.add_argument( parser.add_argument(
'--url', "--vllm-api-key", default="EMPTY", help="API key for vLLM compatible services"
default=("https://docs.vllm.ai/en/latest/getting_started/" )
"quickstart.html"), parser.add_argument(
help='URL of the document to process') "--vllm-embedding-endpoint",
parser.add_argument('--embedding-model', default="http://localhost:8000/v1",
default="ssmits/Qwen2-7B-Instruct-embed-base", help="Base URL for embedding service",
help='Model name for embeddings') )
parser.add_argument('--chat-model', parser.add_argument(
default="qwen/Qwen1.5-0.5B-Chat", "--vllm-chat-endpoint",
help='Model name for chat') default="http://localhost:8001/v1",
parser.add_argument('-i', help="Base URL for chat service",
'--interactive', )
action='store_true', parser.add_argument("--uri", default="./milvus.db", help="URI for Milvus database")
help='Enable interactive Q&A mode') parser.add_argument(
parser.add_argument('-k', "--url",
'--top-k', default=("https://docs.vllm.ai/en/latest/getting_started/quickstart.html"),
type=int, help="URL of the document to process",
default=3, )
help='Number of top results to retrieve') parser.add_argument(
parser.add_argument('-c', "--embedding-model",
'--chunk-size', default="ssmits/Qwen2-7B-Instruct-embed-base",
type=int, help="Model name for embeddings",
default=1000, )
help='Chunk size for document splitting') parser.add_argument(
parser.add_argument('-o', "--chat-model", default="qwen/Qwen1.5-0.5B-Chat", help="Model name for chat"
'--chunk-overlap', )
type=int, parser.add_argument(
default=200, "-i", "--interactive", action="store_true", help="Enable interactive Q&A mode"
help='Chunk overlap for document splitting') )
parser.add_argument(
"-k", "--top-k", type=int, default=3, help="Number of top results to retrieve"
)
parser.add_argument(
"-c",
"--chunk-size",
type=int,
default=1000,
help="Chunk size for document splitting",
)
parser.add_argument(
"-o",
"--chunk-overlap",
type=int,
default=200,
help="Chunk overlap for document splitting",
)
return parser return parser
...@@ -198,7 +205,7 @@ def init_config(args: Namespace): ...@@ -198,7 +205,7 @@ def init_config(args: Namespace):
"url": args.url, "url": args.url,
"chunk_size": args.chunk_size, "chunk_size": args.chunk_size,
"chunk_overlap": args.chunk_overlap, "chunk_overlap": args.chunk_overlap,
"top_k": args.top_k "top_k": args.top_k,
} }
...@@ -230,7 +237,7 @@ def main(): ...@@ -230,7 +237,7 @@ def main():
while True: while True:
question = input("\nPlease enter your question: ") question = input("\nPlease enter your question: ")
if question.lower() in ['q', 'quit']: if question.lower() in ["q", "quit"]:
print("\nThank you for using! Goodbye!") print("\nThank you for using! Goodbye!")
break break
...@@ -238,7 +245,7 @@ def main(): ...@@ -238,7 +245,7 @@ def main():
print(output) print(output)
else: else:
# Default single question mode # Default single question mode
question = ("How to install vLLM?") question = "How to install vLLM?"
output = qa_chain.invoke(question) output = qa_chain.invoke(question)
print("-" * 50) print("-" * 50)
print(output) print(output)
......
...@@ -35,6 +35,7 @@ Notes: ...@@ -35,6 +35,7 @@ Notes:
- Default ports: 8000 (embedding), 8001 (chat) - Default ports: 8000 (embedding), 8001 (chat)
- First run may take time to download models - First run may take time to download models
""" """
import argparse import argparse
from argparse import Namespace from argparse import Namespace
from typing import Any from typing import Any
...@@ -59,7 +60,7 @@ def init_config(args: Namespace): ...@@ -59,7 +60,7 @@ def init_config(args: Namespace):
"db_path": args.db_path, "db_path": args.db_path,
"chunk_size": args.chunk_size, "chunk_size": args.chunk_size,
"chunk_overlap": args.chunk_overlap, "chunk_overlap": args.chunk_overlap,
"top_k": args.top_k "top_k": args.top_k,
} }
...@@ -117,52 +118,58 @@ def query_document(index: VectorStoreIndex, question: str, top_k: int): ...@@ -117,52 +118,58 @@ def query_document(index: VectorStoreIndex, question: str, top_k: int):
def get_parser() -> argparse.ArgumentParser: def get_parser() -> argparse.ArgumentParser:
"""Parse command line arguments""" """Parse command line arguments"""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="RAG with vLLM and LlamaIndex")
description='RAG with vLLM and LlamaIndex')
# Add command line arguments # Add command line arguments
parser.add_argument( parser.add_argument(
'--url', "--url",
default=("https://docs.vllm.ai/en/latest/getting_started/" default=("https://docs.vllm.ai/en/latest/getting_started/quickstart.html"),
"quickstart.html"), help="URL of the document to process",
help='URL of the document to process') )
parser.add_argument('--embedding-model', parser.add_argument(
default="ssmits/Qwen2-7B-Instruct-embed-base", "--embedding-model",
help='Model name for embeddings') default="ssmits/Qwen2-7B-Instruct-embed-base",
parser.add_argument('--chat-model', help="Model name for embeddings",
default="qwen/Qwen1.5-0.5B-Chat", )
help='Model name for chat') parser.add_argument(
parser.add_argument('--vllm-api-key', "--chat-model", default="qwen/Qwen1.5-0.5B-Chat", help="Model name for chat"
default="EMPTY", )
help='API key for vLLM compatible services') parser.add_argument(
parser.add_argument('--embedding-endpoint', "--vllm-api-key", default="EMPTY", help="API key for vLLM compatible services"
default="http://localhost:8000/v1", )
help='Base URL for embedding service') parser.add_argument(
parser.add_argument('--chat-endpoint', "--embedding-endpoint",
default="http://localhost:8001/v1", default="http://localhost:8000/v1",
help='Base URL for chat service') help="Base URL for embedding service",
parser.add_argument('--db-path', )
default="./milvus_demo.db", parser.add_argument(
help='Path to Milvus database') "--chat-endpoint",
parser.add_argument('-i', default="http://localhost:8001/v1",
'--interactive', help="Base URL for chat service",
action='store_true', )
help='Enable interactive Q&A mode') parser.add_argument(
parser.add_argument('-c', "--db-path", default="./milvus_demo.db", help="Path to Milvus database"
'--chunk-size', )
type=int, parser.add_argument(
default=1000, "-i", "--interactive", action="store_true", help="Enable interactive Q&A mode"
help='Chunk size for document splitting') )
parser.add_argument('-o', parser.add_argument(
'--chunk-overlap', "-c",
type=int, "--chunk-size",
default=200, type=int,
help='Chunk overlap for document splitting') default=1000,
parser.add_argument('-k', help="Chunk size for document splitting",
'--top-k', )
type=int, parser.add_argument(
default=3, "-o",
help='Number of top results to retrieve') "--chunk-overlap",
type=int,
default=200,
help="Chunk overlap for document splitting",
)
parser.add_argument(
"-k", "--top-k", type=int, default=3, help="Number of top results to retrieve"
)
return parser return parser
...@@ -193,7 +200,7 @@ def main(): ...@@ -193,7 +200,7 @@ def main():
question = input("\nEnter your question: ") question = input("\nEnter your question: ")
# Check for exit command # Check for exit command
if question.lower() in ['quit', 'exit', 'q']: if question.lower() in ["quit", "exit", "q"]:
print("Exiting interactive mode...") print("Exiting interactive mode...")
break break
......
...@@ -26,6 +26,7 @@ Usage: ...@@ -26,6 +26,7 @@ Usage:
streamlit run streamlit_openai_chatbot_webserver.py \ streamlit run streamlit_openai_chatbot_webserver.py \
--logger.level=debug --logger.level=debug
""" """
import os import os
from datetime import datetime from datetime import datetime
...@@ -33,8 +34,8 @@ import streamlit as st ...@@ -33,8 +34,8 @@ import streamlit as st
from openai import OpenAI from openai import OpenAI
# Get command line arguments from environment variables # Get command line arguments from environment variables
openai_api_key = os.getenv('VLLM_API_KEY', "EMPTY") openai_api_key = os.getenv("VLLM_API_KEY", "EMPTY")
openai_api_base = os.getenv('VLLM_API_BASE', "http://localhost:8000/v1") openai_api_base = os.getenv("VLLM_API_BASE", "http://localhost:8000/v1")
# Initialize session states for managing chat sessions # Initialize session states for managing chat sessions
if "sessions" not in st.session_state: if "sessions" not in st.session_state:
...@@ -81,9 +82,9 @@ def get_llm_response(messages, model): ...@@ -81,9 +82,9 @@ def get_llm_response(messages, model):
Streaming response object or error message string Streaming response object or error message string
""" """
try: try:
response = client.chat.completions.create(model=model, response = client.chat.completions.create(
messages=messages, model=model, messages=messages, stream=True
stream=True) )
return response return response
except Exception as e: except Exception as e:
st.error(f"Error details: {str(e)}") st.error(f"Error details: {str(e)}")
...@@ -92,8 +93,9 @@ def get_llm_response(messages, model): ...@@ -92,8 +93,9 @@ def get_llm_response(messages, model):
# Sidebar - API Settings first # Sidebar - API Settings first
st.sidebar.title("API Settings") st.sidebar.title("API Settings")
new_api_base = st.sidebar.text_input("API Base URL:", new_api_base = st.sidebar.text_input(
value=st.session_state.api_base_url) "API Base URL:", value=st.session_state.api_base_url
)
if new_api_base != st.session_state.api_base_url: if new_api_base != st.session_state.api_base_url:
st.session_state.api_base_url = new_api_base st.session_state.api_base_url = new_api_base
st.rerun() st.rerun()
...@@ -109,16 +111,20 @@ if st.sidebar.button("New Session"): ...@@ -109,16 +111,20 @@ if st.sidebar.button("New Session"):
for session_id in sorted(st.session_state.sessions.keys(), reverse=True): for session_id in sorted(st.session_state.sessions.keys(), reverse=True):
# Mark the active session with a pinned button # Mark the active session with a pinned button
if session_id == st.session_state.active_session: if session_id == st.session_state.active_session:
st.sidebar.button(f"📍 {session_id}", st.sidebar.button(
key=session_id, f"📍 {session_id}",
type="primary", key=session_id,
on_click=switch_to_chat_session, type="primary",
args=(session_id, )) on_click=switch_to_chat_session,
args=(session_id,),
)
else: else:
st.sidebar.button(f"Session {session_id}", st.sidebar.button(
key=session_id, f"Session {session_id}",
on_click=switch_to_chat_session, key=session_id,
args=(session_id, )) on_click=switch_to_chat_session,
args=(session_id,),
)
# Main interface # Main interface
st.title("vLLM Chat Assistant") st.title("vLLM Chat Assistant")
...@@ -145,18 +151,18 @@ for message in st.session_state.messages: ...@@ -145,18 +151,18 @@ for message in st.session_state.messages:
if prompt := st.chat_input("Type your message here..."): if prompt := st.chat_input("Type your message here..."):
# Save user message to session # Save user message to session
st.session_state.messages.append({"role": "user", "content": prompt}) st.session_state.messages.append({"role": "user", "content": prompt})
st.session_state.sessions[ st.session_state.sessions[st.session_state.current_session] = (
st.session_state.current_session] = st.session_state.messages st.session_state.messages
)
# Display user message # Display user message
with st.chat_message("user"): with st.chat_message("user"):
st.write(prompt) st.write(prompt)
# Prepare messages for llm # Prepare messages for llm
messages_for_llm = [{ messages_for_llm = [
"role": m["role"], {"role": m["role"], "content": m["content"]} for m in st.session_state.messages
"content": m["content"] ]
} for m in st.session_state.messages]
# Generate and display llm response # Generate and display llm response
with st.chat_message("assistant"): with st.chat_message("assistant"):
...@@ -179,7 +185,4 @@ if prompt := st.chat_input("Type your message here..."): ...@@ -179,7 +185,4 @@ if prompt := st.chat_input("Type your message here..."):
message_placeholder.markdown(full_response) message_placeholder.markdown(full_response)
# Save llm response to session history # Save llm response to session history
st.session_state.messages.append({ st.session_state.messages.append({"role": "assistant", "content": full_response})
"role": "assistant",
"content": full_response
})
...@@ -16,10 +16,10 @@ def get_first_model(client: OpenAI) -> str: ...@@ -16,10 +16,10 @@ def get_first_model(client: OpenAI) -> str:
f"{client.base_url} with API key {client.api_key}. Check\n" f"{client.base_url} with API key {client.api_key}. Check\n"
"1. the server is running\n" "1. the server is running\n"
"2. the server URL is correct\n" "2. the server URL is correct\n"
"3. the API key is correct") from e "3. the API key is correct"
) from e
if len(models.data) == 0: if len(models.data) == 0:
raise RuntimeError( raise RuntimeError(f"No models found on the vLLM server at {client.base_url}")
f"No models found on the vLLM server at {client.base_url}")
return models.data[0].id return models.data[0].id
...@@ -20,6 +20,7 @@ Requirements: Linux, Python: 3.10 or higher, CUDA: 12.1 ...@@ -20,6 +20,7 @@ Requirements: Linux, Python: 3.10 or higher, CUDA: 12.1
Learn more about LMCache environment setup, please refer to: Learn more about LMCache environment setup, please refer to:
https://docs.lmcache.ai/getting_started/installation.html https://docs.lmcache.ai/getting_started/installation.html
""" """
import argparse import argparse
import contextlib import contextlib
import os import os
...@@ -49,8 +50,7 @@ def setup_environment_variables(vllm_version: str): ...@@ -49,8 +50,7 @@ def setup_environment_variables(vllm_version: str):
@contextlib.contextmanager @contextlib.contextmanager
def build_llm_with_lmcache(lmcache_connector: str, model: str, def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str):
vllm_version: str):
ktc = KVTransferConfig( ktc = KVTransferConfig(
kv_connector=lmcache_connector, kv_connector=lmcache_connector,
kv_role="kv_both", kv_role="kv_both",
...@@ -97,18 +97,19 @@ def print_output( ...@@ -97,18 +97,19 @@ def print_output(
for output in outputs: for output in outputs:
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}") print(f"Generated text: {generated_text!r}")
print(f"Generation took {time.time() - start:.2f} seconds, " print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
f"{req_str} request done.")
print("-" * 50) print("-" * 50)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-v", parser.add_argument(
"--version", "-v",
choices=["v0", "v1"], "--version",
default="v1", choices=["v0", "v1"],
help="Specify vLLM version (default: v1)") default="v1",
help="Specify vLLM version (default: v1)",
)
return parser.parse_args() return parser.parse_args()
...@@ -125,7 +126,6 @@ def main(): ...@@ -125,7 +126,6 @@ def main():
setup_environment_variables(args.version) setup_environment_variables(args.version)
with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm: with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm:
# This example script runs two requests with a shared prefix. # This example script runs two requests with a shared prefix.
# Define the shared prompt and specific prompts # Define the shared prompt and specific prompts
shared_prompt = "Hello, how are you?" * 1000 shared_prompt = "Hello, how are you?" * 1000
...@@ -136,9 +136,7 @@ def main(): ...@@ -136,9 +136,7 @@ def main():
shared_prompt + "Tell me a very long story", shared_prompt + "Tell me a very long story",
] ]
sampling_params = SamplingParams(temperature=0, sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
top_p=0.95,
max_tokens=10)
# Print the first output # Print the first output
print_output(llm, first_prompt, sampling_params, "first") print_output(llm, first_prompt, sampling_params, "first")
......
...@@ -4,12 +4,13 @@ This file demonstrates the example usage of disaggregated prefilling ...@@ -4,12 +4,13 @@ This file demonstrates the example usage of disaggregated prefilling
with LMCache. with LMCache.
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode), We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
and launch an additional LMCache server. and launch an additional LMCache server.
KV cache is transferred in the following manner: KV cache is transferred in the following manner:
vLLM prefill node -> LMCache server -> vLLM decode node. vLLM prefill node -> LMCache server -> vLLM decode node.
Note that `pip install lmcache` is needed to run this example. Note that `pip install lmcache` is needed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache. Learn more about LMCache in https://github.com/LMCache/LMCache.
""" """
import os import os
import subprocess import subprocess
import time import time
...@@ -49,19 +50,23 @@ def run_prefill(prefill_done, prompts): ...@@ -49,19 +50,23 @@ def run_prefill(prefill_done, prompts):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
ktc = KVTransferConfig(kv_connector="LMCacheConnector", ktc = KVTransferConfig(
kv_role="kv_producer", kv_connector="LMCacheConnector",
kv_rank=0, kv_role="kv_producer",
kv_parallel_size=2) kv_rank=0,
kv_parallel_size=2,
)
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory. # memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", llm = LLM(
kv_transfer_config=ktc, model="mistralai/Mistral-7B-Instruct-v0.2",
max_model_len=8000, kv_transfer_config=ktc,
gpu_memory_utilization=0.8, max_model_len=8000,
enforce_eager=True) gpu_memory_utilization=0.8,
enforce_eager=True,
#llm.generate(prompts, sampling_params) )
# llm.generate(prompts, sampling_params)
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
for output in outputs: for output in outputs:
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
...@@ -79,17 +84,21 @@ def run_decode(prefill_done, prompts, timeout=1): ...@@ -79,17 +84,21 @@ def run_decode(prefill_done, prompts, timeout=1):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig(kv_connector="LMCacheConnector", ktc = KVTransferConfig(
kv_role="kv_consumer", kv_connector="LMCacheConnector",
kv_rank=1, kv_role="kv_consumer",
kv_parallel_size=2) kv_rank=1,
kv_parallel_size=2,
)
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory. # of memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", llm = LLM(
kv_transfer_config=ktc, model="mistralai/Mistral-7B-Instruct-v0.2",
max_model_len=8000, kv_transfer_config=ktc,
gpu_memory_utilization=0.8, max_model_len=8000,
enforce_eager=True) gpu_memory_utilization=0.8,
enforce_eager=True,
)
print("Waiting for prefill node to finish...") print("Waiting for prefill node to finish...")
prefill_done.wait() prefill_done.wait()
...@@ -105,10 +114,9 @@ def run_decode(prefill_done, prompts, timeout=1): ...@@ -105,10 +114,9 @@ def run_decode(prefill_done, prompts, timeout=1):
def run_lmcache_server(port): def run_lmcache_server(port):
server_proc = subprocess.Popen([ server_proc = subprocess.Popen(
"python", "-m", "lmcache.experimental.server", "localhost", ["python", "-m", "lmcache.experimental.server", "localhost", str(port)]
str(port) )
])
return server_proc return server_proc
......
...@@ -17,13 +17,17 @@ async def lifespan(app: FastAPI): ...@@ -17,13 +17,17 @@ async def lifespan(app: FastAPI):
Lifespan context manager to handle startup and shutdown events. Lifespan context manager to handle startup and shutdown events.
""" """
# Startup: Initialize clients # Startup: Initialize clients
prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' prefiller_base_url = (
decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' f"http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1"
)
app.state.prefill_client = httpx.AsyncClient(timeout=None, decoder_base_url = (
base_url=prefiller_base_url) f"http://{global_args.decoder_host}:{global_args.decoder_port}/v1"
app.state.decode_client = httpx.AsyncClient(timeout=None, )
base_url=decoder_base_url)
app.state.prefill_client = httpx.AsyncClient(
timeout=None, base_url=prefiller_base_url
)
app.state.decode_client = httpx.AsyncClient(timeout=None, base_url=decoder_base_url)
yield yield
...@@ -37,7 +41,6 @@ app = FastAPI(lifespan=lifespan) ...@@ -37,7 +41,6 @@ app = FastAPI(lifespan=lifespan)
class StatsCalculator: class StatsCalculator:
def __init__(self): def __init__(self):
self._stats = [] self._stats = []
self._last_log_time = time.time() self._last_log_time = time.time()
...@@ -51,13 +54,18 @@ class StatsCalculator: ...@@ -51,13 +54,18 @@ class StatsCalculator:
def _log_stats(self): def _log_stats(self):
# Print average, median, and 99th percentile # Print average, median, and 99th percentile
np_arr = np.array(self._stats) np_arr = np.array(self._stats)
output_str = f"\nNum requests: {len(self._stats)}" + \ output_str = (
"\nPrefill node TTFT stats:" + \ f"\nNum requests: {len(self._stats)}"
f"\n - Average (ms): {np.mean(np_arr)}" + \ + "\nPrefill node TTFT stats:"
f"\n - Median (ms): {np.median(np_arr)}" + \ + f"\n - Average (ms): {np.mean(np_arr)}"
f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n" + f"\n - Median (ms): {np.median(np_arr)}"
print("===============================", output_str, + f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
"===============================") )
print(
"===============================",
output_str,
"===============================",
)
stats_calculator = StatsCalculator() stats_calculator = StatsCalculator()
...@@ -82,15 +90,16 @@ app.state.prefill_client = None ...@@ -82,15 +90,16 @@ app.state.prefill_client = None
app.state.decode_client = None app.state.decode_client = None
async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, async def send_request_to_service(
req_data: dict): client: httpx.AsyncClient, endpoint: str, req_data: dict
):
""" """
Send a request to a service using a persistent client. Send a request to a service using a persistent client.
""" """
req_data = req_data.copy() req_data = req_data.copy()
req_data['max_tokens'] = 1 req_data["max_tokens"] = 1
if 'max_completion_tokens' in req_data: if "max_completion_tokens" in req_data:
req_data['max_completion_tokens'] = 1 req_data["max_completion_tokens"] = 1
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
response = await client.post(endpoint, json=req_data, headers=headers) response = await client.post(endpoint, json=req_data, headers=headers)
...@@ -98,14 +107,16 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, ...@@ -98,14 +107,16 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
return response return response
async def stream_service_response(client: httpx.AsyncClient, endpoint: str, async def stream_service_response(
req_data: dict): client: httpx.AsyncClient, endpoint: str, req_data: dict
):
""" """
Asynchronously stream the response from a service using a persistent client. Asynchronously stream the response from a service using a persistent client.
""" """
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
async with client.stream("POST", endpoint, json=req_data, async with client.stream(
headers=headers) as response: "POST", endpoint, json=req_data, headers=headers
) as response:
response.raise_for_status() response.raise_for_status()
async for chunk in response.aiter_bytes(): async for chunk in response.aiter_bytes():
yield chunk yield chunk
...@@ -121,28 +132,28 @@ async def handle_completions(request: Request): ...@@ -121,28 +132,28 @@ async def handle_completions(request: Request):
req_data = await request.json() req_data = await request.json()
# Send request to prefill service, ignore the response # Send request to prefill service, ignore the response
await send_request_to_service(app.state.prefill_client, "/completions", await send_request_to_service(
req_data) app.state.prefill_client, "/completions", req_data
)
et = time.time() et = time.time()
stats_calculator.add(et - st) stats_calculator.add(et - st)
# Stream response from decode service # Stream response from decode service
async def generate_stream(): async def generate_stream():
async for chunk in stream_service_response(app.state.decode_client, async for chunk in stream_service_response(
"/completions", app.state.decode_client, "/completions", req_data
req_data): ):
yield chunk yield chunk
return StreamingResponse(generate_stream(), return StreamingResponse(generate_stream(), media_type="text/event-stream")
media_type="text/event-stream")
except Exception as e: except Exception as e:
import sys import sys
import traceback import traceback
exc_info = sys.exc_info() exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server" print("Error occurred in disagg prefill proxy server - completions endpoint")
" - completions endpoint")
print(e) print(e)
print("".join(traceback.format_exception(*exc_info))) print("".join(traceback.format_exception(*exc_info)))
raise raise
...@@ -158,36 +169,39 @@ async def handle_chat_completions(request: Request): ...@@ -158,36 +169,39 @@ async def handle_chat_completions(request: Request):
req_data = await request.json() req_data = await request.json()
# Send request to prefill service, ignore the response # Send request to prefill service, ignore the response
await send_request_to_service(app.state.prefill_client, await send_request_to_service(
"/chat/completions", req_data) app.state.prefill_client, "/chat/completions", req_data
)
et = time.time() et = time.time()
stats_calculator.add(et - st) stats_calculator.add(et - st)
# Stream response from decode service # Stream response from decode service
async def generate_stream(): async def generate_stream():
async for chunk in stream_service_response(app.state.decode_client, async for chunk in stream_service_response(
"/chat/completions", app.state.decode_client, "/chat/completions", req_data
req_data): ):
yield chunk yield chunk
return StreamingResponse(generate_stream(), return StreamingResponse(generate_stream(), media_type="text/event-stream")
media_type="text/event-stream")
except Exception as e: except Exception as e:
import sys import sys
import traceback import traceback
exc_info = sys.exc_info() exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server " print(
" - chat completions endpoint") "Error occurred in disagg prefill proxy server - chat completions endpoint"
)
print(e) print(e)
print("".join(traceback.format_exception(*exc_info))) print("".join(traceback.format_exception(*exc_info)))
raise raise
if __name__ == '__main__': if __name__ == "__main__":
global global_args global global_args
global_args = parse_args() global_args = parse_args()
import uvicorn import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port) uvicorn.run(app, host=global_args.host, port=global_args.port)
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
This file demonstrates the example usage of remote KV cache sharing This file demonstrates the example usage of remote KV cache sharing
with LMCache. with LMCache.
We will launch 2 vllm instances, and launch an additional LMCache server. We will launch 2 vllm instances, and launch an additional LMCache server.
KV cache is transferred in the following manner: KV cache is transferred in the following manner:
(1) vLLM instance 1 -> LMCache server (KV cache store). (1) vLLM instance 1 -> LMCache server (KV cache store).
(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve). (2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve).
Note that lmcache needs to be installed to run this example. Note that lmcache needs to be installed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache. Learn more about LMCache in https://github.com/LMCache/LMCache.
""" """
import os import os
import subprocess import subprocess
import time import time
...@@ -49,15 +50,16 @@ def run_store(store_done, prompts): ...@@ -49,15 +50,16 @@ def run_store(store_done, prompts):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
kv_role="kv_both")
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory. # memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", llm = LLM(
kv_transfer_config=ktc, model="mistralai/Mistral-7B-Instruct-v0.2",
max_model_len=8000, kv_transfer_config=ktc,
gpu_memory_utilization=0.8, max_model_len=8000,
enforce_eager=True) gpu_memory_utilization=0.8,
enforce_eager=True,
)
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
for output in outputs: for output in outputs:
...@@ -76,15 +78,16 @@ def run_retrieve(store_done, prompts, timeout=1): ...@@ -76,15 +78,16 @@ def run_retrieve(store_done, prompts, timeout=1):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
kv_role="kv_both")
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory. # of memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", llm = LLM(
kv_transfer_config=ktc, model="mistralai/Mistral-7B-Instruct-v0.2",
max_model_len=8000, kv_transfer_config=ktc,
gpu_memory_utilization=0.8, max_model_len=8000,
enforce_eager=True) gpu_memory_utilization=0.8,
enforce_eager=True,
)
print("Waiting for KV cache store to finish...") print("Waiting for KV cache store to finish...")
store_done.wait() store_done.wait()
...@@ -100,10 +103,9 @@ def run_retrieve(store_done, prompts, timeout=1): ...@@ -100,10 +103,9 @@ def run_retrieve(store_done, prompts, timeout=1):
def run_lmcache_server(port): def run_lmcache_server(port):
server_proc = subprocess.Popen([ server_proc = subprocess.Popen(
"python", "-m", "lmcache.experimental.server", "localhost", ["python", "-m", "lmcache.experimental.server", "localhost", str(port)]
str(port) )
])
return server_proc return server_proc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment