Commit 4eabe123 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
...@@ -12,15 +12,17 @@ from enum import Enum ...@@ -12,15 +12,17 @@ from enum import Enum
from openai import BadRequestError, OpenAI from openai import BadRequestError, OpenAI
from pydantic import BaseModel from pydantic import BaseModel
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
# Guided decoding by Choice (list of possible options) # Guided decoding by Choice (list of possible options)
def guided_choice_completion(client: OpenAI, model: str): def guided_choice_completion(client: OpenAI, model: str):
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
"content": "Classify this sentiment: vLLM is wonderful!" ],
}],
extra_body={"guided_choice": ["positive", "negative"]}, extra_body={"guided_choice": ["positive", "negative"]},
) )
return completion.choices[0].message.content return completion.choices[0].message.content
...@@ -28,20 +30,21 @@ def guided_choice_completion(client: OpenAI, model: str): ...@@ -28,20 +30,21 @@ def guided_choice_completion(client: OpenAI, model: str):
# Guided decoding by Regex # Guided decoding by Regex
def guided_regex_completion(client: OpenAI, model: str): def guided_regex_completion(client: OpenAI, model: str):
prompt = ("Generate an email address for Alan Turing, who works in Enigma." prompt = (
"Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:" "End in .com and new line. Example result:"
"alan.turing@enigma.com\n") "alan.turing@enigma.com\n"
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
{
"role": "user", "role": "user",
"content": prompt, "content": prompt,
}], }
extra_body={ ],
"guided_regex": r"\w+@\w+\.com\n", extra_body={"guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"]},
"stop": ["\n"]
},
) )
return completion.choices[0].message.content return completion.choices[0].message.content
...@@ -63,14 +66,18 @@ class CarDescription(BaseModel): ...@@ -63,14 +66,18 @@ class CarDescription(BaseModel):
def guided_json_completion(client: OpenAI, model: str): def guided_json_completion(client: OpenAI, model: str):
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
prompt = ("Generate a JSON with the brand, model and car_type of" prompt = (
"the most iconic car from the 90's") "Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's"
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
{
"role": "user", "role": "user",
"content": prompt, "content": prompt,
}], }
],
extra_body={"guided_json": json_schema}, extra_body={"guided_json": json_schema},
) )
return completion.choices[0].message.content return completion.choices[0].message.content
...@@ -92,14 +99,18 @@ def guided_grammar_completion(client: OpenAI, model: str): ...@@ -92,14 +99,18 @@ def guided_grammar_completion(client: OpenAI, model: str):
number ::= "1 " | "2 " number ::= "1 " | "2 "
""" """
prompt = ("Generate an SQL query to show the 'username' and 'email'" prompt = (
"from the 'users' table.") "Generate an SQL query to show the 'username' and 'email'"
"from the 'users' table."
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
{
"role": "user", "role": "user",
"content": prompt, "content": prompt,
}], }
],
extra_body={"guided_grammar": simplified_sql_grammar}, extra_body={"guided_grammar": simplified_sql_grammar},
) )
return completion.choices[0].message.content return completion.choices[0].message.content
...@@ -107,19 +118,23 @@ def guided_grammar_completion(client: OpenAI, model: str): ...@@ -107,19 +118,23 @@ def guided_grammar_completion(client: OpenAI, model: str):
# Extra backend options # Extra backend options
def extra_backend_options_completion(client: OpenAI, model: str): def extra_backend_options_completion(client: OpenAI, model: str):
prompt = ("Generate an email address for Alan Turing, who works in Enigma." prompt = (
"Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:" "End in .com and new line. Example result:"
"alan.turing@enigma.com\n") "alan.turing@enigma.com\n"
)
try: try:
# The guided_decoding_disable_fallback option forces vLLM to use # The guided_decoding_disable_fallback option forces vLLM to use
# xgrammar, so when it fails you get a 400 with the reason why # xgrammar, so when it fails you get a 400 with the reason why
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
{
"role": "user", "role": "user",
"content": prompt, "content": prompt,
}], }
],
extra_body={ extra_body={
"guided_regex": r"\w+@\w+\.com\n", "guided_regex": r"\w+@\w+\.com\n",
"stop": ["\n"], "stop": ["\n"],
...@@ -134,8 +149,8 @@ def extra_backend_options_completion(client: OpenAI, model: str): ...@@ -134,8 +149,8 @@ def extra_backend_options_completion(client: OpenAI, model: str):
def main(): def main():
client: OpenAI = OpenAI( client: OpenAI = OpenAI(
base_url="http://localhost:8000/v1", base_url=openai_api_base,
api_key="-", api_key=openai_api_key,
) )
model = client.models.list().data[0].id model = client.models.list().data[0].id
......
...@@ -7,18 +7,20 @@ from openai import OpenAI ...@@ -7,18 +7,20 @@ from openai import OpenAI
# to enforce the format of a tool call response, but it could be used for # to enforce the format of a tool call response, but it could be used for
# any structured output within a subset of the response. # any structured output within a subset of the response.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
def main(): def main():
client = OpenAI( client = OpenAI(
base_url="http://localhost:8000/v1", base_url=openai_api_base,
api_key="-", api_key=openai_api_key,
) )
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": "content": """
"""
You have access to the following function to retrieve the weather in a city: You have access to the following function to retrieve the weather in a city:
{ {
...@@ -55,29 +57,28 @@ You are a helpful assistant. ...@@ -55,29 +57,28 @@ You are a helpful assistant.
Given the previous instructions, what is the weather in New York City, Boston, Given the previous instructions, what is the weather in New York City, Boston,
and San Francisco? and San Francisco?
""" """,
}] }
]
response = client.chat.completions.create( response = client.chat.completions.create(
model=client.models.list().data[0].id, model=client.models.list().data[0].id,
messages=messages, messages=messages,
response_format={ response_format={
"type": "type": "structural_tag",
"structural_tag", "structures": [
"structures": [{ {
"begin": "<function=get_weather>", "begin": "<function=get_weather>",
"schema": { "schema": {
"type": "object", "type": "object",
"properties": { "properties": {"city": {"type": "string"}},
"city": { },
"type": "string" "end": "</function>",
}
} }
],
"triggers": ["<function="],
}, },
"end": "</function>" )
}],
"triggers": ["<function="]
})
print(response) print(response)
......
...@@ -27,21 +27,22 @@ openai_api_base = "http://localhost:8000/v1" ...@@ -27,21 +27,22 @@ openai_api_base = "http://localhost:8000/v1"
def print_completion_details(completion): def print_completion_details(completion):
print("reasoning_content: ", print("reasoning_content: ", completion.choices[0].message.reasoning_content)
completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content) print("content: ", completion.choices[0].message.content)
# Guided decoding by Regex # Guided decoding by Regex
def guided_regex_completion(client: OpenAI, model: str): def guided_regex_completion(client: OpenAI, model: str):
prompt = ("What is the capital of France?") prompt = "What is the capital of France?"
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
{
"role": "user", "role": "user",
"content": prompt, "content": prompt,
}], }
],
extra_body={ extra_body={
"guided_regex": "(Paris|London)", "guided_regex": "(Paris|London)",
}, },
...@@ -57,13 +58,15 @@ class People(BaseModel): ...@@ -57,13 +58,15 @@ class People(BaseModel):
def guided_json_completion(client: OpenAI, model: str): def guided_json_completion(client: OpenAI, model: str):
json_schema = People.model_json_schema() json_schema = People.model_json_schema()
prompt = ("Generate a JSON with the name and age of one random person.") prompt = "Generate a JSON with the name and age of one random person."
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
{
"role": "user", "role": "user",
"content": prompt, "content": prompt,
}], }
],
extra_body={"guided_json": json_schema}, extra_body={"guided_json": json_schema},
) )
print_completion_details(completion) print_completion_details(completion)
...@@ -86,14 +89,18 @@ class CarDescription(BaseModel): ...@@ -86,14 +89,18 @@ class CarDescription(BaseModel):
def guided_car_json_completion(client: OpenAI, model: str): def guided_car_json_completion(client: OpenAI, model: str):
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
prompt = ("Generate a JSON with the brand, model and car_type of" prompt = (
"the most iconic car from the 90's") "Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's"
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
{
"role": "user", "role": "user",
"content": prompt, "content": prompt,
}], }
],
extra_body={"guided_json": json_schema}, extra_body={"guided_json": json_schema},
) )
print_completion_details(completion) print_completion_details(completion)
...@@ -116,14 +123,18 @@ def guided_grammar_completion(client: OpenAI, model: str): ...@@ -116,14 +123,18 @@ def guided_grammar_completion(client: OpenAI, model: str):
""" """
# This may be very slow https://github.com/vllm-project/vllm/issues/12122 # This may be very slow https://github.com/vllm-project/vllm/issues/12122
prompt = ("Generate an SQL query to show the 'username' and 'email'" prompt = (
"from the 'users' table.") "Generate an SQL query to show the 'username' and 'email'"
"from the 'users' table."
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
{
"role": "user", "role": "user",
"content": prompt, "content": prompt,
}], }
],
extra_body={"guided_grammar": simplified_sql_grammar}, extra_body={"guided_grammar": simplified_sql_grammar},
) )
print_completion_details(completion) print_completion_details(completion)
......
...@@ -20,9 +20,11 @@ from openai import OpenAI ...@@ -20,9 +20,11 @@ from openai import OpenAI
# Now, simulate a tool call # Now, simulate a tool call
def get_current_weather(city: str, state: str, unit: 'str'): def get_current_weather(city: str, state: str, unit: "str"):
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " return (
"partly cloudly, with highs in the 90's.") "The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's."
)
available_tools = {"get_current_weather": get_current_weather} available_tools = {"get_current_weather": get_current_weather}
...@@ -31,49 +33,47 @@ available_tools = {"get_current_weather": get_current_weather} ...@@ -31,49 +33,47 @@ available_tools = {"get_current_weather": get_current_weather}
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1" openai_api_base = "http://localhost:8000/v1"
tools = [{ properties = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": { "city": {
"type": "type": "string",
"string", "description": "The city to find the weather for, e.g. 'San Francisco'",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
}, },
"state": { "state": {
"type": "type": "string",
"string", "description": "the two-letter abbreviation for the state that the city is"
"description": " in, e.g. 'CA' which would mean 'California'",
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
}, },
"unit": { "unit": {
"type": "string", "type": "string",
"description": "The unit to fetch the temperature in", "description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"] "enum": ["celsius", "fahrenheit"],
} },
}
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": properties,
"required": ["city", "state", "unit"],
},
}, },
"required": ["city", "state", "unit"]
}
} }
}] ]
messages = [{ messages = [
{"role": "user", "content": "Hi! How are you doing today?"},
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
{
"role": "user", "role": "user",
"content": "Hi! How are you doing today?" "content": (
}, {
"role": "assistant",
"content": "I'm doing well! How can I help you?"
}, {
"role":
"user",
"content":
"Can you tell me what the temperate will be in Dallas, in fahrenheit?" "Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}] ),
},
]
def extract_reasoning_and_calls(chunks: list): def extract_reasoning_and_calls(chunks: list):
...@@ -110,73 +110,55 @@ def main(): ...@@ -110,73 +110,55 @@ def main():
models = client.models.list() models = client.models.list()
model = models.data[0].id model = models.data[0].id
print("---------Full Generate With Automatic Function Calling-------------")
tool_calls = client.chat.completions.create(
messages=messages, model=model, tools=tools
)
print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
print(f"function name: {tool_calls.choices[0].message.tool_calls[0].function.name}")
print( print(
"---------Full Generate With Automatic Function Calling-------------") f"function arguments: "
tool_calls = client.chat.completions.create(messages=messages, f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}"
model=model,
tools=tools)
print(
f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}"
) )
print(f"function name: "
f"{tool_calls.choices[0].message.tool_calls[0].function.name}")
print(f"function arguments: "
f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}")
print( print("----------Stream Generate With Automatic Function Calling-----------")
"----------Stream Generate With Automatic Function Calling-----------") tool_calls_stream = client.chat.completions.create(
tool_calls_stream = client.chat.completions.create(messages=messages, messages=messages, model=model, tools=tools, stream=True
model=model, )
tools=tools,
stream=True)
chunks = list(tool_calls_stream) chunks = list(tool_calls_stream)
reasoning_content, arguments, function_names = extract_reasoning_and_calls( reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
chunks)
print(f"reasoning_content: {reasoning_content}") print(f"reasoning_content: {reasoning_content}")
print(f"function name: {function_names[0]}") print(f"function name: {function_names[0]}")
print(f"function arguments: {arguments[0]}") print(f"function arguments: {arguments[0]}")
print( print("----------Full Generate With Named Function Calling-----------------")
"----------Full Generate With Named Function Calling-----------------") tool_calls = client.chat.completions.create(
tool_calls = client.chat.completions.create(messages=messages, messages=messages,
model=model, model=model,
tools=tools, tools=tools,
tool_choice={ tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
"type": "function", )
"function": {
"name":
"get_current_weather"
}
})
tool_call = tool_calls.choices[0].message.tool_calls[0].function tool_call = tool_calls.choices[0].message.tool_calls[0].function
print( print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}"
)
print(f"function name: {tool_call.name}") print(f"function name: {tool_call.name}")
print(f"function arguments: {tool_call.arguments}") print(f"function arguments: {tool_call.arguments}")
print( print("----------Stream Generate With Named Function Calling--------------")
"----------Stream Generate With Named Function Calling--------------")
tool_calls_stream = client.chat.completions.create( tool_calls_stream = client.chat.completions.create(
messages=messages, messages=messages,
model=model, model=model,
tools=tools, tools=tools,
tool_choice={ tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
"type": "function", stream=True,
"function": { )
"name": "get_current_weather"
}
},
stream=True)
chunks = list(tool_calls_stream) chunks = list(tool_calls_stream)
reasoning_content, arguments, function_names = extract_reasoning_and_calls( reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
chunks)
print(f"reasoning_content: {reasoning_content}") print(f"reasoning_content: {reasoning_content}")
print(f"function name: {function_names[0]}") print(f"function name: {function_names[0]}")
print(f"function arguments: {arguments[0]}") print(f"function arguments: {arguments[0]}")
......
...@@ -45,12 +45,12 @@ def main(): ...@@ -45,12 +45,12 @@ def main():
# Round 2 # Round 2
messages.append({"role": "assistant", "content": content}) messages.append({"role": "assistant", "content": content})
messages.append({ messages.append(
"role": {
"user", "role": "user",
"content": "content": "How many Rs are there in the word 'strawberry'?",
"How many Rs are there in the word 'strawberry'?", }
}) )
response = client.chat.completions.create(model=model, messages=messages) response = client.chat.completions.create(model=model, messages=messages)
reasoning_content = response.choices[0].message.reasoning_content reasoning_content = response.choices[0].message.reasoning_content
......
...@@ -43,9 +43,7 @@ def main(): ...@@ -43,9 +43,7 @@ def main():
# ruff: noqa: E501 # ruff: noqa: E501
# For granite: add: `extra_body={"chat_template_kwargs": {"thinking": True}}` # For granite: add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
stream = client.chat.completions.create(model=model, stream = client.chat.completions.create(model=model, messages=messages, stream=True)
messages=messages,
stream=True)
print("client: Start streaming chat completions...") print("client: Start streaming chat completions...")
printed_reasoning_content = False printed_reasoning_content = False
......
...@@ -14,26 +14,17 @@ def vlm2vec(): ...@@ -14,26 +14,17 @@ def vlm2vec():
response = requests.post( response = requests.post(
"http://localhost:8000/v1/embeddings", "http://localhost:8000/v1/embeddings",
json={ json={
"model": "model": "TIGER-Lab/VLM2Vec-Full",
"TIGER-Lab/VLM2Vec-Full", "messages": [
"messages": [{
"role":
"user",
"content": [
{ {
"type": "image_url", "role": "user",
"image_url": { "content": [
"url": image_url {"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": "Represent the given image."},
],
} }
},
{
"type": "text",
"text": "Represent the given image."
},
], ],
}], "encoding_format": "float",
"encoding_format":
"float",
}, },
) )
response.raise_for_status() response.raise_for_status()
...@@ -45,19 +36,20 @@ def vlm2vec(): ...@@ -45,19 +36,20 @@ def vlm2vec():
def dse_qwen2_vl(inp: dict): def dse_qwen2_vl(inp: dict):
# Embedding an Image # Embedding an Image
if inp["type"] == "image": if inp["type"] == "image":
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [{ "content": [
{
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
"url": inp["image_url"], "url": inp["image_url"],
},
},
{"type": "text", "text": "What is shown in this image?"},
],
} }
}, { ]
"type": "text",
"text": "What is shown in this image?"
}]
}]
# Embedding a Text Query # Embedding a Text Query
else: else:
# MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
...@@ -66,23 +58,21 @@ def dse_qwen2_vl(inp: dict): ...@@ -66,23 +58,21 @@ def dse_qwen2_vl(inp: dict):
image_placeholder = Image.new("RGB", (56, 56)) image_placeholder = Image.new("RGB", (56, 56))
image_placeholder.save(buffer, "png") image_placeholder.save(buffer, "png")
buffer.seek(0) buffer.seek(0)
image_placeholder = base64.b64encode(buffer.read()).decode('utf-8') image_placeholder = base64.b64encode(buffer.read()).decode("utf-8")
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
"url": f"data:image/jpeg;base64,{image_placeholder}", "url": f"data:image/jpeg;base64,{image_placeholder}",
}
}, },
{
"type": "text",
"text": f"Query: {inp['content']}"
}, },
{"type": "text", "text": f"Query: {inp['content']}"},
],
}
] ]
}]
response = requests.post( response = requests.post(
"http://localhost:8000/v1/embeddings", "http://localhost:8000/v1/embeddings",
...@@ -101,12 +91,15 @@ def dse_qwen2_vl(inp: dict): ...@@ -101,12 +91,15 @@ def dse_qwen2_vl(inp: dict):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
"Script to call a specified VLM through the API. Make sure to serve " "Script to call a specified VLM through the API. Make sure to serve "
"the model with --task embed before running this.") "the model with --task embed before running this."
parser.add_argument("--model", )
parser.add_argument(
"--model",
type=str, type=str,
choices=["vlm2vec", "dse_qwen2_vl"], choices=["vlm2vec", "dse_qwen2_vl"],
required=True, required=True,
help="Which model to call.") help="Which model to call.",
)
return parser.parse_args() return parser.parse_args()
...@@ -114,16 +107,20 @@ def main(args): ...@@ -114,16 +107,20 @@ def main(args):
if args.model == "vlm2vec": if args.model == "vlm2vec":
vlm2vec() vlm2vec()
elif args.model == "dse_qwen2_vl": elif args.model == "dse_qwen2_vl":
dse_qwen2_vl({ dse_qwen2_vl(
{
"type": "image", "type": "image",
"image_url": image_url, "image_url": image_url,
}) }
dse_qwen2_vl({ )
dse_qwen2_vl(
{
"type": "text", "type": "text",
"content": "What is the weather like today?", "content": "What is the weather like today?",
}) }
)
if __name__ == '__main__': if __name__ == "__main__":
args = parse_args() args = parse_args()
main(args) main(args)
...@@ -16,9 +16,7 @@ def parse_args(): ...@@ -16,9 +16,7 @@ def parse_args():
parse = argparse.ArgumentParser() parse = argparse.ArgumentParser()
parse.add_argument("--host", type=str, default="localhost") parse.add_argument("--host", type=str, default="localhost")
parse.add_argument("--port", type=int, default=8000) parse.add_argument("--port", type=int, default=8000)
parse.add_argument("--model", parse.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
type=str,
default="jason9693/Qwen2.5-1.5B-apeach")
return parse.parse_args() return parse.parse_args()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import argparse
from openai import OpenAI from openai import OpenAI
# Modify OpenAI's API key and API base to use vLLM's API server. # Modify OpenAI's API key and API base to use vLLM's API server.
...@@ -7,7 +9,15 @@ openai_api_key = "EMPTY" ...@@ -7,7 +9,15 @@ openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1" openai_api_base = "http://localhost:8000/v1"
def main(): def parse_args():
parser = argparse.ArgumentParser(description="Client for vLLM API server")
parser.add_argument(
"--stream", action="store_true", help="Enable streaming response"
)
return parser.parse_args()
def main(args):
client = OpenAI( client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY") # defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key, api_key=openai_api_key,
...@@ -18,18 +28,18 @@ def main(): ...@@ -18,18 +28,18 @@ def main():
model = models.data[0].id model = models.data[0].id
# Completion API # Completion API
stream = False
completion = client.completions.create( completion = client.completions.create(
model=model, model=model,
prompt="A robot may not injure a human being", prompt="A robot may not injure a human being",
echo=False, echo=False,
n=2, n=2,
stream=stream, stream=args.stream,
logprobs=3) logprobs=3,
)
print("-" * 50) print("-" * 50)
print("Completion results:") print("Completion results:")
if stream: if args.stream:
for c in completion: for c in completion:
print(c) print(c)
else: else:
...@@ -38,4 +48,5 @@ def main(): ...@@ -38,4 +48,5 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() args = parse_args()
main(args)
...@@ -4,6 +4,7 @@ Example online usage of Score API. ...@@ -4,6 +4,7 @@ Example online usage of Score API.
Run `vllm serve <model> --task score` to start up the server in vLLM. Run `vllm serve <model> --task score` to start up the server in vLLM.
""" """
import argparse import argparse
import pprint import pprint
...@@ -38,9 +39,7 @@ def main(args): ...@@ -38,9 +39,7 @@ def main(args):
pprint.pprint(score_response.json()) pprint.pprint(score_response.json())
text_1 = "What is the capital of France?" text_1 = "What is the capital of France?"
text_2 = [ text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url) score_response = post_http_request(prompt=prompt, api_url=api_url)
print("\nPrompt when text_1 is string and text_2 is a list:") print("\nPrompt when text_1 is string and text_2 is a list:")
...@@ -48,12 +47,8 @@ def main(args): ...@@ -48,12 +47,8 @@ def main(args):
print("\nScore Response:") print("\nScore Response:")
pprint.pprint(score_response.json()) pprint.pprint(score_response.json())
text_1 = [ text_1 = ["What is the capital of Brazil?", "What is the capital of France?"]
"What is the capital of Brazil?", "What is the capital of France?" text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
]
text_2 = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url) score_response = post_http_request(prompt=prompt, api_url=api_url)
print("\nPrompt when text_1 and text_2 are both lists:") print("\nPrompt when text_1 and text_2 are both lists:")
......
...@@ -21,7 +21,7 @@ def main(): ...@@ -21,7 +21,7 @@ def main():
# ruff: noqa: E501 # ruff: noqa: E501
input=[ input=[
"Hello my name is", "Hello my name is",
"The best thing about vLLM is that it supports many different models" "The best thing about vLLM is that it supports many different models",
], ],
model=model, model=model,
) )
......
...@@ -5,6 +5,7 @@ Example online usage of Pooling API. ...@@ -5,6 +5,7 @@ Example online usage of Pooling API.
Run `vllm serve <model> --task <embed|classify|reward|score>` Run `vllm serve <model> --task <embed|classify|reward|score>`
to start up the server in vLLM. to start up the server in vLLM.
""" """
import argparse import argparse
import pprint import pprint
...@@ -21,9 +22,7 @@ def parse_args(): ...@@ -21,9 +22,7 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000) parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", parser.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
type=str,
default="jason9693/Qwen2.5-1.5B-apeach")
return parser.parse_args() return parser.parse_args()
...@@ -42,15 +41,13 @@ def main(args): ...@@ -42,15 +41,13 @@ def main(args):
# Input like Chat API # Input like Chat API
prompt = { prompt = {
"model": "model": model_name,
model_name, "messages": [
"messages": [{ {
"role": "user", "role": "user",
"content": [{ "content": [{"type": "text", "text": "vLLM is great!"}],
"type": "text", }
"text": "vLLM is great!" ],
}],
}]
} }
pooling_response = post_http_request(prompt=prompt, api_url=api_url) pooling_response = post_http_request(prompt=prompt, api_url=api_url)
print("Pooling Response:") print("Pooling Response:")
......
...@@ -7,8 +7,8 @@ from openai import OpenAI ...@@ -7,8 +7,8 @@ from openai import OpenAI
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path() mary_had_lamb = AudioAsset("mary_had_lamb").get_local_path()
winning_call = AudioAsset('winning_call').get_local_path() winning_call = AudioAsset("winning_call").get_local_path()
# Modify OpenAI's API key and API base to use vLLM's API server. # Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
...@@ -31,7 +31,8 @@ def sync_openai(): ...@@ -31,7 +31,8 @@ def sync_openai():
extra_body=dict( extra_body=dict(
seed=4419, seed=4419,
repetition_penalty=1.3, repetition_penalty=1.3,
)) ),
)
print("transcription result:", transcription.text) print("transcription result:", transcription.text)
...@@ -42,33 +43,30 @@ sync_openai() ...@@ -42,33 +43,30 @@ sync_openai()
async def stream_openai_response(): async def stream_openai_response():
data = { data = {
"language": "en", "language": "en",
'stream': True, "stream": True,
"model": "openai/whisper-large-v3", "model": "openai/whisper-large-v3",
} }
url = openai_api_base + "/audio/transcriptions" url = openai_api_base + "/audio/transcriptions"
headers = {"Authorization": f"Bearer {openai_api_key}"} headers = {"Authorization": f"Bearer {openai_api_key}"}
print("transcription result:", end=' ') print("transcription result:", end=" ")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
with open(str(winning_call), "rb") as f: with open(str(winning_call), "rb") as f:
async with client.stream('POST', async with client.stream(
url, "POST", url, files={"file": f}, data=data, headers=headers
files={'file': f}, ) as response:
data=data,
headers=headers) as response:
async for line in response.aiter_lines(): async for line in response.aiter_lines():
# Each line is a JSON object prefixed with 'data: ' # Each line is a JSON object prefixed with 'data: '
if line: if line:
if line.startswith('data: '): if line.startswith("data: "):
line = line[len('data: '):] line = line[len("data: ") :]
# Last chunk, stream ends # Last chunk, stream ends
if line.strip() == '[DONE]': if line.strip() == "[DONE]":
break break
# Parse the JSON response # Parse the JSON response
chunk = json.loads(line) chunk = json.loads(line)
# Extract and print the content # Extract and print the content
content = chunk['choices'][0].get('delta', content = chunk["choices"][0].get("delta", {}).get("content")
{}).get('content') print(content, end="")
print(content, end='')
# Run the asynchronous function # Run the asynchronous function
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import requests import requests
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
OTLPSpanExporter)
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (BatchSpanProcessor, from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
ConsoleSpanExporter)
from opentelemetry.trace import SpanKind, set_tracer_provider from opentelemetry.trace import SpanKind, set_tracer_provider
from opentelemetry.trace.propagation.tracecontext import ( from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
TraceContextTextMapPropagator)
trace_provider = TracerProvider() trace_provider = TracerProvider()
set_tracer_provider(trace_provider) set_tracer_provider(trace_provider)
......
# SPDX-License-Identifier: Apache-2.0
"""
vLLM OpenAI-Compatible Client with Prompt Embeddings
This script demonstrates how to:
1. Generate prompt embeddings using Hugging Face Transformers
2. Encode them in base64 format
3. Send them to a vLLM server via the OpenAI-compatible Completions API
Run the vLLM server first:
vllm serve meta-llama/Llama-3.2-1B-Instruct \
--task generate \
--max-model-len 4096 \
--enable-prompt-embeds
Run the client:
python examples/online_serving/prompt_embed_inference_with_openai_client.py
Model: meta-llama/Llama-3.2-1B-Instruct
Note: This model is gated on Hugging Face Hub.
You must request access to use it:
https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct
Dependencies:
- transformers
- torch
- openai
"""
import base64
import io
import torch
import transformers
from openai import OpenAI
def main():
    """Send Hugging Face prompt embeddings to a vLLM server and print the reply.

    The steps are: build an OpenAI client for the local endpoint, load the
    tokenizer and causal-LM weights, turn a one-turn chat into token ids,
    look those ids up in the model's input embedding layer, serialize the
    resulting tensor with ``torch.save``, base64-encode the bytes, and pass
    them to the Completions API through ``extra_body["prompt_embeds"]``.
    """
    client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

    model_name = "meta-llama/Llama-3.2-1B-Instruct"

    # Hugging Face side: tokenizer plus the model that owns the embedding layer
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)

    # Refer to the HuggingFace repo for the correct chat format to use
    messages = [
        {"role": "user", "content": "Please tell me about the capital of France."}
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    )

    # Embed the token ids and drop the leading dimension with squeeze(0)
    prompt_embeds = transformers_model.get_input_embeddings()(input_ids).squeeze(0)

    # Serialize the tensor in memory, then base64-encode it for the JSON body
    with io.BytesIO() as stream:
        torch.save(prompt_embeds, stream)
        serialized = stream.getvalue()
    encoded_embeds = base64.b64encode(serialized).decode("utf-8")

    # NOTE: The OpenAI client allows passing in extra JSON body via the
    # `extra_body` argument.
    extra_fields = {"prompt_embeds": encoded_embeds}

    completion = client.completions.create(
        model=model_name,
        # NOTE: The OpenAI client does not allow `None` as an input to
        # `prompt`. Use an empty string if you have no text prompts.
        prompt="",
        max_tokens=5,
        temperature=0.0,
        extra_body=extra_fields,
    )

    rule = "-" * 30
    print(rule)
    print(completion.choices[0].text)
    print(rule)


if __name__ == "__main__":
    main()
...@@ -28,9 +28,7 @@ llm_config = LLMConfig( ...@@ -28,9 +28,7 @@ llm_config = LLMConfig(
}, },
# Change to the accelerator type of the node # Change to the accelerator type of the node
accelerator_type="H100", accelerator_type="H100",
runtime_env={"env_vars": { runtime_env={"env_vars": {"VLLM_USE_V1": "1"}},
"VLLM_USE_V1": "1"
}},
# Customize engine arguments as needed (e.g. vLLM engine kwargs) # Customize engine arguments as needed (e.g. vLLM engine kwargs)
engine_kwargs={ engine_kwargs={
"tensor_parallel_size": 8, "tensor_parallel_size": 8,
......
...@@ -55,7 +55,7 @@ def load_and_split_documents(config: dict[str, Any]): ...@@ -55,7 +55,7 @@ def load_and_split_documents(config: dict[str, Any]):
Load and split documents from web URL Load and split documents from web URL
""" """
try: try:
loader = WebBaseLoader(web_paths=(config["url"], )) loader = WebBaseLoader(web_paths=(config["url"],))
docs = loader.load() docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
...@@ -121,64 +121,71 @@ def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate): ...@@ -121,64 +121,71 @@ def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate):
""" """
Set up question answering chain Set up question answering chain
""" """
return ({ return (
{
"context": retriever | format_docs, "context": retriever | format_docs,
"question": RunnablePassthrough(), "question": RunnablePassthrough(),
} }
| prompt | prompt
| llm | llm
| StrOutputParser()) | StrOutputParser()
)
def get_parser() -> argparse.ArgumentParser: def get_parser() -> argparse.ArgumentParser:
""" """
Parse command line arguments Parse command line arguments
""" """
parser = argparse.ArgumentParser(description='RAG with vLLM and langchain') parser = argparse.ArgumentParser(description="RAG with vLLM and langchain")
# Add command line arguments # Add command line arguments
parser.add_argument('--vllm-api-key', parser.add_argument(
default="EMPTY", "--vllm-api-key", default="EMPTY", help="API key for vLLM compatible services"
help='API key for vLLM compatible services') )
parser.add_argument('--vllm-embedding-endpoint', parser.add_argument(
"--vllm-embedding-endpoint",
default="http://localhost:8000/v1", default="http://localhost:8000/v1",
help='Base URL for embedding service') help="Base URL for embedding service",
parser.add_argument('--vllm-chat-endpoint', )
parser.add_argument(
"--vllm-chat-endpoint",
default="http://localhost:8001/v1", default="http://localhost:8001/v1",
help='Base URL for chat service') help="Base URL for chat service",
parser.add_argument('--uri', )
default="./milvus.db", parser.add_argument("--uri", default="./milvus.db", help="URI for Milvus database")
help='URI for Milvus database') parser.add_argument(
"--url",
default=("https://docs.vllm.ai/en/latest/getting_started/quickstart.html"),
help="URL of the document to process",
)
parser.add_argument( parser.add_argument(
'--url', "--embedding-model",
default=("https://docs.vllm.ai/en/latest/getting_started/"
"quickstart.html"),
help='URL of the document to process')
parser.add_argument('--embedding-model',
default="ssmits/Qwen2-7B-Instruct-embed-base", default="ssmits/Qwen2-7B-Instruct-embed-base",
help='Model name for embeddings') help="Model name for embeddings",
parser.add_argument('--chat-model', )
default="qwen/Qwen1.5-0.5B-Chat", parser.add_argument(
help='Model name for chat') "--chat-model", default="qwen/Qwen1.5-0.5B-Chat", help="Model name for chat"
parser.add_argument('-i', )
'--interactive', parser.add_argument(
action='store_true', "-i", "--interactive", action="store_true", help="Enable interactive Q&A mode"
help='Enable interactive Q&A mode') )
parser.add_argument('-k', parser.add_argument(
'--top-k', "-k", "--top-k", type=int, default=3, help="Number of top results to retrieve"
type=int, )
default=3, parser.add_argument(
help='Number of top results to retrieve') "-c",
parser.add_argument('-c', "--chunk-size",
'--chunk-size',
type=int, type=int,
default=1000, default=1000,
help='Chunk size for document splitting') help="Chunk size for document splitting",
parser.add_argument('-o', )
'--chunk-overlap', parser.add_argument(
"-o",
"--chunk-overlap",
type=int, type=int,
default=200, default=200,
help='Chunk overlap for document splitting') help="Chunk overlap for document splitting",
)
return parser return parser
...@@ -198,7 +205,7 @@ def init_config(args: Namespace): ...@@ -198,7 +205,7 @@ def init_config(args: Namespace):
"url": args.url, "url": args.url,
"chunk_size": args.chunk_size, "chunk_size": args.chunk_size,
"chunk_overlap": args.chunk_overlap, "chunk_overlap": args.chunk_overlap,
"top_k": args.top_k "top_k": args.top_k,
} }
...@@ -230,7 +237,7 @@ def main(): ...@@ -230,7 +237,7 @@ def main():
while True: while True:
question = input("\nPlease enter your question: ") question = input("\nPlease enter your question: ")
if question.lower() in ['q', 'quit']: if question.lower() in ["q", "quit"]:
print("\nThank you for using! Goodbye!") print("\nThank you for using! Goodbye!")
break break
...@@ -238,7 +245,7 @@ def main(): ...@@ -238,7 +245,7 @@ def main():
print(output) print(output)
else: else:
# Default single question mode # Default single question mode
question = ("How to install vLLM?") question = "How to install vLLM?"
output = qa_chain.invoke(question) output = qa_chain.invoke(question)
print("-" * 50) print("-" * 50)
print(output) print(output)
......
...@@ -35,6 +35,7 @@ Notes: ...@@ -35,6 +35,7 @@ Notes:
- Default ports: 8000 (embedding), 8001 (chat) - Default ports: 8000 (embedding), 8001 (chat)
- First run may take time to download models - First run may take time to download models
""" """
import argparse import argparse
from argparse import Namespace from argparse import Namespace
from typing import Any from typing import Any
...@@ -59,7 +60,7 @@ def init_config(args: Namespace): ...@@ -59,7 +60,7 @@ def init_config(args: Namespace):
"db_path": args.db_path, "db_path": args.db_path,
"chunk_size": args.chunk_size, "chunk_size": args.chunk_size,
"chunk_overlap": args.chunk_overlap, "chunk_overlap": args.chunk_overlap,
"top_k": args.top_k "top_k": args.top_k,
} }
...@@ -117,52 +118,58 @@ def query_document(index: VectorStoreIndex, question: str, top_k: int): ...@@ -117,52 +118,58 @@ def query_document(index: VectorStoreIndex, question: str, top_k: int):
def get_parser() -> argparse.ArgumentParser: def get_parser() -> argparse.ArgumentParser:
"""Parse command line arguments""" """Parse command line arguments"""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="RAG with vLLM and LlamaIndex")
description='RAG with vLLM and LlamaIndex')
# Add command line arguments # Add command line arguments
parser.add_argument( parser.add_argument(
'--url', "--url",
default=("https://docs.vllm.ai/en/latest/getting_started/" default=("https://docs.vllm.ai/en/latest/getting_started/quickstart.html"),
"quickstart.html"), help="URL of the document to process",
help='URL of the document to process') )
parser.add_argument('--embedding-model', parser.add_argument(
"--embedding-model",
default="ssmits/Qwen2-7B-Instruct-embed-base", default="ssmits/Qwen2-7B-Instruct-embed-base",
help='Model name for embeddings') help="Model name for embeddings",
parser.add_argument('--chat-model', )
default="qwen/Qwen1.5-0.5B-Chat", parser.add_argument(
help='Model name for chat') "--chat-model", default="qwen/Qwen1.5-0.5B-Chat", help="Model name for chat"
parser.add_argument('--vllm-api-key', )
default="EMPTY", parser.add_argument(
help='API key for vLLM compatible services') "--vllm-api-key", default="EMPTY", help="API key for vLLM compatible services"
parser.add_argument('--embedding-endpoint', )
parser.add_argument(
"--embedding-endpoint",
default="http://localhost:8000/v1", default="http://localhost:8000/v1",
help='Base URL for embedding service') help="Base URL for embedding service",
parser.add_argument('--chat-endpoint', )
parser.add_argument(
"--chat-endpoint",
default="http://localhost:8001/v1", default="http://localhost:8001/v1",
help='Base URL for chat service') help="Base URL for chat service",
parser.add_argument('--db-path', )
default="./milvus_demo.db", parser.add_argument(
help='Path to Milvus database') "--db-path", default="./milvus_demo.db", help="Path to Milvus database"
parser.add_argument('-i', )
'--interactive', parser.add_argument(
action='store_true', "-i", "--interactive", action="store_true", help="Enable interactive Q&A mode"
help='Enable interactive Q&A mode') )
parser.add_argument('-c', parser.add_argument(
'--chunk-size', "-c",
"--chunk-size",
type=int, type=int,
default=1000, default=1000,
help='Chunk size for document splitting') help="Chunk size for document splitting",
parser.add_argument('-o', )
'--chunk-overlap', parser.add_argument(
"-o",
"--chunk-overlap",
type=int, type=int,
default=200, default=200,
help='Chunk overlap for document splitting') help="Chunk overlap for document splitting",
parser.add_argument('-k', )
'--top-k', parser.add_argument(
type=int, "-k", "--top-k", type=int, default=3, help="Number of top results to retrieve"
default=3, )
help='Number of top results to retrieve')
return parser return parser
...@@ -193,7 +200,7 @@ def main(): ...@@ -193,7 +200,7 @@ def main():
question = input("\nEnter your question: ") question = input("\nEnter your question: ")
# Check for exit command # Check for exit command
if question.lower() in ['quit', 'exit', 'q']: if question.lower() in ["quit", "exit", "q"]:
print("Exiting interactive mode...") print("Exiting interactive mode...")
break break
......
...@@ -26,6 +26,7 @@ Usage: ...@@ -26,6 +26,7 @@ Usage:
streamlit run streamlit_openai_chatbot_webserver.py \ streamlit run streamlit_openai_chatbot_webserver.py \
--logger.level=debug --logger.level=debug
""" """
import os import os
from datetime import datetime from datetime import datetime
...@@ -33,8 +34,8 @@ import streamlit as st ...@@ -33,8 +34,8 @@ import streamlit as st
from openai import OpenAI from openai import OpenAI
# Get command line arguments from environment variables # Get command line arguments from environment variables
openai_api_key = os.getenv('VLLM_API_KEY', "EMPTY") openai_api_key = os.getenv("VLLM_API_KEY", "EMPTY")
openai_api_base = os.getenv('VLLM_API_BASE', "http://localhost:8000/v1") openai_api_base = os.getenv("VLLM_API_BASE", "http://localhost:8000/v1")
# Initialize session states for managing chat sessions # Initialize session states for managing chat sessions
if "sessions" not in st.session_state: if "sessions" not in st.session_state:
...@@ -81,9 +82,9 @@ def get_llm_response(messages, model): ...@@ -81,9 +82,9 @@ def get_llm_response(messages, model):
Streaming response object or error message string Streaming response object or error message string
""" """
try: try:
response = client.chat.completions.create(model=model, response = client.chat.completions.create(
messages=messages, model=model, messages=messages, stream=True
stream=True) )
return response return response
except Exception as e: except Exception as e:
st.error(f"Error details: {str(e)}") st.error(f"Error details: {str(e)}")
...@@ -92,8 +93,9 @@ def get_llm_response(messages, model): ...@@ -92,8 +93,9 @@ def get_llm_response(messages, model):
# Sidebar - API Settings first # Sidebar - API Settings first
st.sidebar.title("API Settings") st.sidebar.title("API Settings")
new_api_base = st.sidebar.text_input("API Base URL:", new_api_base = st.sidebar.text_input(
value=st.session_state.api_base_url) "API Base URL:", value=st.session_state.api_base_url
)
if new_api_base != st.session_state.api_base_url: if new_api_base != st.session_state.api_base_url:
st.session_state.api_base_url = new_api_base st.session_state.api_base_url = new_api_base
st.rerun() st.rerun()
...@@ -109,16 +111,20 @@ if st.sidebar.button("New Session"): ...@@ -109,16 +111,20 @@ if st.sidebar.button("New Session"):
for session_id in sorted(st.session_state.sessions.keys(), reverse=True): for session_id in sorted(st.session_state.sessions.keys(), reverse=True):
# Mark the active session with a pinned button # Mark the active session with a pinned button
if session_id == st.session_state.active_session: if session_id == st.session_state.active_session:
st.sidebar.button(f"📍 {session_id}", st.sidebar.button(
f"📍 {session_id}",
key=session_id, key=session_id,
type="primary", type="primary",
on_click=switch_to_chat_session, on_click=switch_to_chat_session,
args=(session_id, )) args=(session_id,),
)
else: else:
st.sidebar.button(f"Session {session_id}", st.sidebar.button(
f"Session {session_id}",
key=session_id, key=session_id,
on_click=switch_to_chat_session, on_click=switch_to_chat_session,
args=(session_id, )) args=(session_id,),
)
# Main interface # Main interface
st.title("vLLM Chat Assistant") st.title("vLLM Chat Assistant")
...@@ -145,18 +151,18 @@ for message in st.session_state.messages: ...@@ -145,18 +151,18 @@ for message in st.session_state.messages:
if prompt := st.chat_input("Type your message here..."): if prompt := st.chat_input("Type your message here..."):
# Save user message to session # Save user message to session
st.session_state.messages.append({"role": "user", "content": prompt}) st.session_state.messages.append({"role": "user", "content": prompt})
st.session_state.sessions[ st.session_state.sessions[st.session_state.current_session] = (
st.session_state.current_session] = st.session_state.messages st.session_state.messages
)
# Display user message # Display user message
with st.chat_message("user"): with st.chat_message("user"):
st.write(prompt) st.write(prompt)
# Prepare messages for llm # Prepare messages for llm
messages_for_llm = [{ messages_for_llm = [
"role": m["role"], {"role": m["role"], "content": m["content"]} for m in st.session_state.messages
"content": m["content"] ]
} for m in st.session_state.messages]
# Generate and display llm response # Generate and display llm response
with st.chat_message("assistant"): with st.chat_message("assistant"):
...@@ -179,7 +185,4 @@ if prompt := st.chat_input("Type your message here..."): ...@@ -179,7 +185,4 @@ if prompt := st.chat_input("Type your message here..."):
message_placeholder.markdown(full_response) message_placeholder.markdown(full_response)
# Save llm response to session history # Save llm response to session history
st.session_state.messages.append({ st.session_state.messages.append({"role": "assistant", "content": full_response})
"role": "assistant",
"content": full_response
})
...@@ -16,10 +16,10 @@ def get_first_model(client: OpenAI) -> str: ...@@ -16,10 +16,10 @@ def get_first_model(client: OpenAI) -> str:
f"{client.base_url} with API key {client.api_key}. Check\n" f"{client.base_url} with API key {client.api_key}. Check\n"
"1. the server is running\n" "1. the server is running\n"
"2. the server URL is correct\n" "2. the server URL is correct\n"
"3. the API key is correct") from e "3. the API key is correct"
) from e
if len(models.data) == 0: if len(models.data) == 0:
raise RuntimeError( raise RuntimeError(f"No models found on the vLLM server at {client.base_url}")
f"No models found on the vLLM server at {client.base_url}")
return models.data[0].id return models.data[0].id
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment