Unverified Commit 2e341cd4 authored by zhyncs, committed by GitHub

misc: add pre-commit config (#637)

parent a8552cb1
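This commit adds a .pre-commit-config.yaml at the repository root and applies the resulting autoformatting across the tree. The config file itself is not visible in the hunks below, but the changes they contain (strings normalized to double quotes, long calls wrapped with trailing commas, imports regrouped and sorted) are the signatures of black and isort. A plausible minimal sketch of such a config, with placeholder rev pins that are assumptions rather than values taken from this commit:

    repos:
      - repo: https://github.com/pycqa/isort
        rev: 5.13.2  # placeholder pin, not taken from this commit
        hooks:
          - id: isort
      - repo: https://github.com/psf/black
        rev: 24.4.2  # placeholder pin, not taken from this commit
        hooks:
          - id: black

isort is usually aligned with black via --profile black (or the equivalent pyproject.toml setting) so the two hooks do not fight over formatting.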
@@ -3,10 +3,12 @@ Usage:
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
 python json_decode.py
 """
 
 from enum import Enum
-import sglang as sgl
+
 from pydantic import BaseModel
+
+import sglang as sgl
 from sglang.srt.constrained import build_regex_from_object
 
 character_regex = (
...
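The hunk above shows the import-grouping rule that recurs throughout this diff: isort separates standard-library, third-party, and first-party imports with blank lines and sorts each group alphabetically. Assuming sglang is configured as the first-party package (which the reordering here suggests), the enforced layout is:

    from enum import Enum  # 1. standard library

    from pydantic import BaseModel  # 2. third-party packages

    import sglang as sgl  # 3. first-party (the project itself)
    from sglang.srt.constrained import build_regex_from_object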
@@ -14,16 +14,13 @@ Output:
 import argparse
 import asyncio
+import copy
 import json
 import time
-import copy
 
 import aiohttp
 import requests
-from llava.conversation import (
-    conv_llava_llama_3,
-)
+from llava.conversation import conv_llava_llama_3
 
 
 async def send_request(url, data, delay=0):
...
@@ -14,16 +14,13 @@ Output:
 import argparse
 import asyncio
+import copy
 import json
 import time
-import copy
 
 import aiohttp
 import requests
-from llava.conversation import (
-    conv_qwen
-)
+from llava.conversation import conv_qwen
 
 
 async def send_request(url, data, delay=0):
...
@@ -2,13 +2,15 @@
 Usage: python3 srt_example_llava.py
 """
 
+from PIL import ImageFile
+
 import sglang as sgl
-from sglang.srt.utils import load_image
 from sglang.lang.chat_template import get_chat_template
-from PIL import ImageFile
+from sglang.srt.utils import load_image
 
 ImageFile.LOAD_TRUNCATED_IMAGES = True  # Allow loading of truncated images
 
 
 @sgl.function
 def image_qa(s, image, question):
     s += sgl.user(sgl.image(image) + question)
...
@@ -2,15 +2,17 @@
 Usage: python3 srt_example_llava.py
 """
 
-import sglang as sgl
-import os
+import argparse
 import csv
+import os
 import time
-import argparse
+
+import sglang as sgl
 
 
 @sgl.function
 def video_qa(s, num_frames, video_path, question):
-    s += sgl.user(sgl.video(video_path,num_frames) + question)
+    s += sgl.user(sgl.video(video_path, num_frames) + question)
     s += sgl.assistant(sgl.gen("answer"))
@@ -25,7 +27,6 @@ def single(path, num_frames=16):
     print(state["answer"], "\n")
 
-
 def split_into_chunks(lst, num_chunks):
     """Split a list into a specified number of chunks."""
     # Calculate the chunk size using integer division. Note that this may drop some items if not evenly divisible.
@@ -34,7 +35,7 @@ def split_into_chunks(lst, num_chunks):
     if chunk_size == 0:
         chunk_size = len(lst)
     # Use list comprehension to generate chunks. The last chunk will take any remainder if the list size isn't evenly divisible.
-    chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
+    chunks = [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
     # Ensure we have exactly num_chunks chunks, even if some are empty
     chunks.extend([[] for _ in range(num_chunks - len(chunks))])
     return chunks
@@ -42,67 +43,73 @@ def split_into_chunks(lst, num_chunks):
 def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):
     csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
-    with open(csv_filename, 'w', newline='') as csvfile:
+    with open(csv_filename, "w", newline="") as csvfile:
         writer = csv.writer(csvfile)
-        writer.writerow(['video_name', 'answer'])
+        writer.writerow(["video_name", "answer"])
         for video_path, state in zip(batch_video_files, states):
             video_name = os.path.basename(video_path)
             writer.writerow([video_name, state["answer"]])
 
 
 def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):
     final_csv_filename = f"{save_dir}/final_results_chunk_{cur_chunk}.csv"
-    with open(final_csv_filename, 'w', newline='') as final_csvfile:
+    with open(final_csv_filename, "w", newline="") as final_csvfile:
         writer = csv.writer(final_csvfile)
-        writer.writerow(['video_name', 'answer'])
+        writer.writerow(["video_name", "answer"])
         for batch_idx in range(num_batches):
             batch_csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
-            with open(batch_csv_filename, 'r') as batch_csvfile:
+            with open(batch_csv_filename, "r") as batch_csvfile:
                 reader = csv.reader(batch_csvfile)
                 next(reader)  # Skip header row
                 for row in reader:
                     writer.writerow(row)
             os.remove(batch_csv_filename)
 
 
 def find_video_files(video_dir):
     # Check if the video_dir is actually a file
     if os.path.isfile(video_dir):
         # If it's a file, return it as a single-element list
         return [video_dir]
     # Original logic to find video files in a directory
     video_files = []
     for root, dirs, files in os.walk(video_dir):
         for file in files:
-            if file.endswith(('.mp4', '.avi', '.mov')):
+            if file.endswith((".mp4", ".avi", ".mov")):
                 video_files.append(os.path.join(root, file))
     return video_files
 
 
 def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):
     video_files = find_video_files(video_dir)
     chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk]
     num_batches = 0
 
     for i in range(0, len(chunked_video_files), batch_size):
-        batch_video_files = chunked_video_files[i:i + batch_size]
+        batch_video_files = chunked_video_files[i : i + batch_size]
         print(f"Processing batch of {len(batch_video_files)} video(s)...")
 
         if not batch_video_files:
             print("No video files found in the specified directory.")
             return
 
         batch_input = [
             {
                 "num_frames": num_frames,
                 "video_path": video_path,
                 "question": "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.",
-            } for video_path in batch_video_files
+            }
+            for video_path in batch_video_files
         ]
 
         start_time = time.time()
         states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
         total_time = time.time() - start_time
         average_time = total_time / len(batch_video_files)
-        print(f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds")
+        print(
+            f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds"
+        )
 
         save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir)
         num_batches += 1
@@ -113,16 +120,47 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=
 if __name__ == "__main__":
     # Create the parser
-    parser = argparse.ArgumentParser(description='Run video processing with specified port.')
+    parser = argparse.ArgumentParser(
+        description="Run video processing with specified port."
+    )
 
     # Add an argument for the port
-    parser.add_argument('--port', type=int, default=30000, help='The master port for distributed serving.')
-    parser.add_argument('--chunk-idx', type=int, default=0, help='The index of the chunk to process.')
-    parser.add_argument('--num-chunks', type=int, default=8, help='The number of chunks to process.')
-    parser.add_argument('--save-dir', type=str, default="./work_dirs/llava_video", help='The directory to save the processed video files.')
-    parser.add_argument('--video-dir', type=str, default="./videos/Q98Z4OTh8RwmDonc.mp4", help='The directory or path for the processed video files.')
-    parser.add_argument('--model-path', type=str, default="lmms-lab/LLaVA-NeXT-Video-7B", help='The model path for the video processing.')
-    parser.add_argument('--num-frames', type=int, default=16, help='The number of frames to process in each video.' )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=30000,
+        help="The master port for distributed serving.",
+    )
+    parser.add_argument(
+        "--chunk-idx", type=int, default=0, help="The index of the chunk to process."
+    )
+    parser.add_argument(
+        "--num-chunks", type=int, default=8, help="The number of chunks to process."
+    )
+    parser.add_argument(
+        "--save-dir",
+        type=str,
+        default="./work_dirs/llava_video",
+        help="The directory to save the processed video files.",
+    )
+    parser.add_argument(
+        "--video-dir",
+        type=str,
+        default="./videos/Q98Z4OTh8RwmDonc.mp4",
+        help="The directory or path for the processed video files.",
+    )
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        default="lmms-lab/LLaVA-NeXT-Video-7B",
+        help="The model path for the video processing.",
+    )
+    parser.add_argument(
+        "--num-frames",
+        type=int,
+        default=16,
+        help="The number of frames to process in each video.",
+    )
     parser.add_argument("--mm_spatial_pool_stride", type=int, default=2)
 
     # Parse the arguments
@@ -154,7 +192,6 @@ if __name__ == "__main__":
     if "34b" in args.model_path.lower():
         model_overide_args["image_token_index"] = 64002
 
-
     if args.num_frames == 32:
         model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
         model_overide_args["max_sequence_length"] = 4096 * 2
@@ -162,22 +199,22 @@ if __name__ == "__main__":
     elif args.num_frames < 32:
         pass
     else:
-        print("The maximum number of frames to process is 32. Please specify a valid number of frames.")
+        print(
+            "The maximum number of frames to process is 32. Please specify a valid number of frames."
+        )
         exit()
 
     runtime = sgl.Runtime(
-        model_path=args.model_path, #"liuhaotian/llava-v1.6-vicuna-7b",
+        model_path=args.model_path,  # "liuhaotian/llava-v1.6-vicuna-7b",
         tokenizer_path=tokenizer_path,
         port=cur_port,
-        additional_ports=[cur_port+1,cur_port+2,cur_port+3,cur_port+4],
+        additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
         model_overide_args=model_overide_args,
-        tp_size=1
+        tp_size=1,
     )
     sgl.set_default_backend(runtime)
     print(f"chat template: {runtime.endpoint.chat_template.name}")
 
     # Run a single request
     # try:
     print("\n========== single ==========\n")
@@ -185,24 +222,29 @@ if __name__ == "__main__":
     if os.path.isfile(root):
         video_files = [root]
     else:
-        video_files = [os.path.join(root, f) for f in os.listdir(root) if f.endswith(('.mp4', '.avi', '.mov'))]  # Add more extensions if needed
+        video_files = [
+            os.path.join(root, f)
+            for f in os.listdir(root)
+            if f.endswith((".mp4", ".avi", ".mov"))
+        ]  # Add more extensions if needed
     start_time = time.time()  # Start time for processing a single video
     for cur_video in video_files[:1]:
         print(cur_video)
         single(cur_video, num_frames)
     end_time = time.time()  # End time for processing a single video
     total_time = end_time - start_time
-    average_time = total_time / len(video_files)  # Calculate the average processing time
+    average_time = total_time / len(
+        video_files
+    )  # Calculate the average processing time
     print(f"Average processing time per video: {average_time:.2f} seconds")
     runtime.shutdown()
 
     # except Exception as e:
     #     print(e)
 
     runtime.shutdown()
 
     # # # Run a batch of requests
     # print("\n========== batch ==========\n")
     # if not os.path.exists(args.save_dir):
     #     os.makedirs(args.save_dir)
     # batch(args.video_dir,args.save_dir,cur_chunk, num_chunks, num_frames, num_chunks)
-    # runtime.shutdown()
\ No newline at end of file
+    # runtime.shutdown()
@@ -15,23 +15,40 @@ incorrect:
 export OPENAI_API_KEY=sk-******
 python3 openai_chat_speculative.py
 """
 
 import sglang as sgl
-from sglang import function, set_default_backend, OpenAI
+from sglang import OpenAI, function, set_default_backend
 
 
 @function(num_api_spec_tokens=256)
 def gen_character_spec(s):
     s += sgl.system("You are a helpful assistant.")
     s += sgl.user("Construct a character within the following format:")
-    s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
+    s += sgl.assistant(
+        "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+    )
     s += sgl.user("Please generate new Name, Birthday and Job.\n")
-    s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+    s += sgl.assistant(
+        "Name:"
+        + sgl.gen("name", stop="\n")
+        + "\nBirthday:"
+        + sgl.gen("birthday", stop="\n")
+        + "\nJob:"
+        + sgl.gen("job", stop="\n")
+    )
 
 
 @function(num_api_spec_tokens=256)
 def gen_character_spec_no_few_shot(s):
     s += sgl.user("Construct a character. For each field stop with a newline\n")
-    s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+    s += sgl.assistant(
+        "Name:"
+        + sgl.gen("name", stop="\n")
+        + "\nAge:"
+        + sgl.gen("age", stop="\n")
+        + "\nJob:"
+        + sgl.gen("job", stop="\n")
+    )
 
 
 @function
@@ -45,10 +62,19 @@ def gen_character_normal(s):
 def multi_turn_question(s, question_1, question_2):
     s += sgl.system("You are a helpful assistant.")
     s += sgl.user("Answer questions in the following format:")
-    s += sgl.user("Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n")
-    s += sgl.assistant("Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n")
-    s += sgl.user("Question 1: " + question_1+"\nQuestion 2: " + question_2)
-    s += sgl.assistant("Answer 1: " + sgl.gen("answer_1", stop="\n") + "\nAnswer 2: " + sgl.gen("answer_2", stop="\n"))
+    s += sgl.user(
+        "Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n"
+    )
+    s += sgl.assistant(
+        "Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n"
+    )
+    s += sgl.user("Question 1: " + question_1 + "\nQuestion 2: " + question_2)
+    s += sgl.assistant(
+        "Answer 1: "
+        + sgl.gen("answer_1", stop="\n")
+        + "\nAnswer 2: "
+        + sgl.gen("answer_2", stop="\n")
+    )
 
 
 def test_spec_single_turn():
@@ -97,7 +123,7 @@ def test_spec_multi_turn_stream():
     state = multi_turn_question.run(
         question_1="What is the capital of the United States?",
         question_2="List two local attractions.",
-        stream=True
+        stream=True,
     )
 
     for out in state.text_iter():
@@ -126,4 +152,4 @@ if __name__ == "__main__":
     print("\n========== test spec multi turn stream ==========\n")
     # expect error in stream_executor: stream is not supported...
-    test_spec_multi_turn_stream()
\ No newline at end of file
+    test_spec_multi_turn_stream()
@@ -2,7 +2,8 @@
 Usage:
 python3 openai_speculative.py
 """
-from sglang import function, gen, set_default_backend, OpenAI
+
+from sglang import OpenAI, function, gen, set_default_backend
 
 
 @function(num_api_spec_tokens=64)
@@ -35,7 +36,11 @@ if __name__ == "__main__":
     backend = OpenAI("gpt-3.5-turbo-instruct")
     set_default_backend(backend)
 
-    for function in [gen_character_spec, gen_character_no_spec, gen_character_spec_no_few_shot]:
+    for function in [
+        gen_character_spec,
+        gen_character_no_spec,
+        gen_character_spec_no_few_shot,
+    ]:
         backend.token_usage.reset()
 
         print(f"function: {function.func.__name__}")
@@ -46,4 +51,4 @@ if __name__ == "__main__":
         print("...birthday:", state["birthday"])
         print("...job:", state["job"])
         print(backend.token_usage)
-        print()
\ No newline at end of file
+        print()
@@ -2,6 +2,7 @@
 Usage:
 python3 parallel_sample.py
 """
+
 import sglang as sgl
@@ -12,7 +13,6 @@ def parallel_sample(s, question, n):
         "Reasoning: I need to use a calculator.\n"
         "Tool: calculator\n"
         "Answer: 6\n"
-
         "Question: Compute 3 + 2 + 2\n"
         "Reasoning: I will try a calculator.\n"
         "Tool: calculator\n"
@@ -27,13 +27,9 @@ def parallel_sample(s, question, n):
 sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
-#sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
+# sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
 
-state = parallel_sample.run(
-    question="Compute 5 + 2 + 4.",
-    n=5,
-    temperature=1.0
-)
+state = parallel_sample.run(question="Compute 5 + 2 + 4.", n=5, temperature=1.0)
 
 for i in range(5):
     obj = {
...
@@ -3,13 +3,18 @@ Usage:
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
 python readme_examples.py
 """
+
 import sglang as sgl
 
 
 @sgl.function
 def tool_use(s, question):
     s += "To answer this question: " + question + ". "
-    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". "
+    s += (
+        "I need to use a "
+        + sgl.gen("tool", choices=["calculator", "search engine"])
+        + ". "
+    )
 
     if s["tool"] == "calculator":
         s += "The math expression is" + sgl.gen("expression")
...
@@ -75,7 +80,7 @@ def driver_batching():
             {"question": "What is the capital of France?"},
             {"question": "What is the capital of Japan?"},
         ],
-        progress_bar=True
+        progress_bar=True,
     )
 
     for s in states:
@@ -85,9 +90,7 @@ def driver_batching():
 def driver_stream():
     state = text_qa.run(
-        question="What is the capital of France?",
-        temperature=0.1,
-        stream=True
+        question="What is the capital of France?", temperature=0.1, stream=True
     )
 
     for out in state.text_iter():
@@ -96,7 +99,7 @@ def driver_stream():
 if __name__ == "__main__":
-    #sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
+    # sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
     sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
 
     driver_tool_use()
...
@@ -2,7 +2,9 @@
 Usage:
 python3 streaming.py
 """
+
 import asyncio
+
 import sglang as sgl
@@ -22,7 +24,7 @@ def stream_a_variable():
     state = multi_turn_question.run(
         question_1="What is the capital of the United States?",
         question_2="List two local attractions.",
-        stream=True
+        stream=True,
     )
 
     for out in state.text_iter(var_name="answer_2"):
@@ -34,7 +36,7 @@ async def async_stream():
     state = multi_turn_question.run(
         question_1="What is the capital of the United States?",
         question_2="List two local attractions.",
-        stream=True
+        stream=True,
    )
 
     async for out in state.text_async_iter(var_name="answer_2"):
...
-import triton_python_backend_utils as pb_utils
 import numpy
+import triton_python_backend_utils as pb_utils
+from pydantic import BaseModel
+
 import sglang as sgl
 from sglang import function, set_default_backend
 from sglang.srt.constrained import build_regex_from_object
-from pydantic import BaseModel
 
 sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
 
 
 class Character(BaseModel):
     name: str
     eye_color: str
     house: str
 
 
 @function
 def character_gen(s, name):
     s += (
         name
         + " is a character in Harry Potter. Please fill in the following information about this character.\n"
     )
-    s += sgl.gen("json_output", max_tokens=256, regex=build_regex_from_object(Character))
+    s += sgl.gen(
+        "json_output", max_tokens=256, regex=build_regex_from_object(Character)
+    )
 
 
 class TritonPythonModel:
     def initialize(self, args):
         print("Initialized.")
 
     def execute(self, requests):
         responses = []
         for request in requests:
             tensor_in = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT")
             if tensor_in is None:
                 return pb_utils.InferenceResponse(output_tensors=[])
 
-            input_list_names = [i.decode('utf-8') if isinstance(i, bytes) else i for i in tensor_in.as_numpy().tolist()]
-            input_list_dicts = [{"name":i} for i in input_list_names]
+            input_list_names = [
+                i.decode("utf-8") if isinstance(i, bytes) else i
+                for i in tensor_in.as_numpy().tolist()
+            ]
+            input_list_dicts = [{"name": i} for i in input_list_names]
 
             states = character_gen.run_batch(input_list_dicts)
             character_strs = [state.text() for state in states]
-            tensor_out = pb_utils.Tensor("OUTPUT_TEXT", numpy.array(character_strs, dtype=object))
-            responses.append(pb_utils.InferenceResponse(output_tensors = [tensor_out]))
+            tensor_out = pb_utils.Tensor(
+                "OUTPUT_TEXT", numpy.array(character_strs, dtype=object)
+            )
+            responses.append(pb_utils.InferenceResponse(output_tensors=[tensor_out]))
-        return responses
\ No newline at end of file
+        return responses
@@ -3,11 +3,12 @@ import code
 from sglang.srt.hf_transformers_utils import get_tokenizer
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
+    parser.add_argument(
+        "--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct"
+    )
     args = parser.parse_args()
 
     t = get_tokenizer(args.name)
-    code.interact(local=locals())
\ No newline at end of file
+    code.interact(local=locals())
@@ -183,14 +183,18 @@ class CudaGraphRunner:
         else:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=output.next_token_logprobs[:raw_bs]
-                if output.next_token_logprobs is not None
-                else None,
+                next_token_logprobs=(
+                    output.next_token_logprobs[:raw_bs]
+                    if output.next_token_logprobs is not None
+                    else None
+                ),
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
-                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
-                if output.decode_top_logprobs is not None
-                else None,
+                decode_top_logprobs=(
+                    output.decode_top_logprobs[:raw_bs]
+                    if output.decode_top_logprobs is not None
+                    else None
+                ),
             )
 
         return output
"""A controller that manages a group of tensor parallel workers.""" """A controller that manages a group of tensor parallel workers."""
import multiprocessing
import logging import logging
import multiprocessing
import os import os
import pickle import pickle
@@ -11,11 +11,10 @@ import zmq
 import zmq.asyncio
 
 from sglang.srt.managers.controller.tp_worker import ModelTpServer
-from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
+from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger("srt.controller")
@@ -45,14 +44,16 @@ def run_tp_server(
         raise
 
 
-def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
-                      model_port_args, model_overide_args):
+def launch_tp_servers(
+    gpu_ids, tp_rank_range, server_args, model_port_args, model_overide_args
+):
     """Launch multiple tp servers."""
     procs = []
     for i in tp_rank_range:
-        proc = multiprocessing.Process(target=run_tp_server, args=(
-            gpu_ids[i], i, server_args, model_port_args, model_overide_args
-        ))
+        proc = multiprocessing.Process(
+            target=run_tp_server,
+            args=(gpu_ids[i], i, server_args, model_port_args, model_overide_args),
+        )
         proc.start()
         procs.append(proc)
@@ -93,7 +94,9 @@ def broadcast_recv_input(data, rank, dist_group):
 class ControllerSingle:
     """A controller that manages a group of tensor parallel workers."""
 
-    def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
+    def __init__(
+        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict
+    ):
         # Parse args
         self.server_args = server_args
         self.tp_procs = []
@@ -116,8 +119,12 @@ class ControllerSingle:
         if tp_size_local > 1:
             tp_rank_range = range(1, tp_size_local)
             self.tp_procs = launch_tp_servers(
-                gpu_ids, tp_rank_range, server_args,
-                port_args.model_port_args[0], model_overide_args)
+                gpu_ids,
+                tp_rank_range,
+                server_args,
+                port_args.model_port_args[0],
+                model_overide_args,
+            )
 
         # Launch tp rank 0
         self.tp_server = ModelTpServer(
...
@@ -11,7 +11,11 @@ import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
+from vllm.distributed import (
+    get_tp_group,
+    init_distributed_environment,
+    initialize_model_parallel,
+)
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
@@ -89,9 +93,9 @@ class ModelRunner:
         # Set some global args
         global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
-        global_server_args_dict[
-            "attention_reduce_in_fp32"
-        ] = server_args.attention_reduce_in_fp32
+        global_server_args_dict["attention_reduce_in_fp32"] = (
+            server_args.attention_reduce_in_fp32
+        )
 
         # Load the model and create memory pool
         self.load_model()
...
@@ -241,12 +241,9 @@ class ModelTpServer:
     def print_stats(self):
         num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool.available_size()
-            + self.tree_cache.evictable_size()
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-        throughput = self.num_generated_tokens / (
-            time.time() - self.last_stats_tic
-        )
+        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
        self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
         logger.info(
@@ -260,8 +257,7 @@ class ModelTpServer:
     def check_memory(self):
         available_size = (
-            self.token_to_kv_pool.available_size()
-            + self.tree_cache.evictable_size()
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
         if available_size != self.max_total_num_tokens:
             warnings.warn(
@@ -348,7 +344,8 @@ class ModelTpServer:
         if self.running_batch:
             available_size -= sum(
                 [
-                    (r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio
+                    (r.sampling_params.max_new_tokens - len(r.output_ids))
+                    * self.new_token_ratio
                     for r in self.running_batch.reqs
                 ]
             )
@@ -370,7 +367,9 @@ class ModelTpServer:
                 req.image_offset += 1
 
             if (
-                req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
+                req.extend_input_len
+                + req.sampling_params.max_new_tokens
+                + new_batch_total_tokens
                 < available_size
                 and (
                     req.extend_input_len + new_batch_input_tokens
@@ -382,7 +381,9 @@ class ModelTpServer:
                 available_size += delta
 
                 if not (
-                    req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
+                    req.extend_input_len
+                    + req.sampling_params.max_new_tokens
+                    + new_batch_total_tokens
                     < available_size
                 ):
                     # Undo locking
...
@@ -335,15 +335,16 @@ class TokenizerManager:
             )
 
             if top_logprobs_num > 0:
-                ret["meta_info"][
-                    "prefill_top_logprobs"
-                ] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                ret["meta_info"]["prefill_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["prefill_top_logprobs"],
+                        return_text_in_logprobs,
+                    )
                 )
-                ret["meta_info"][
-                    "decode_top_logprobs"
-                ] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                ret["meta_info"]["decode_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                    )
                 )
 
         return ret
...
@@ -21,7 +21,9 @@ class ReqToTokenPool:
         if need_size > self.can_use_mem_size:
             return None
 
-        select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
+        select_index = (
+            torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
+        )
         self.mem_state[select_index] = False
         self.can_use_mem_size -= need_size
@@ -79,7 +81,9 @@ class TokenToKVPool:
         addition_size = need_size - buffer_len
         alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
+        select_index = (
+            torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
+        )
 
         if select_index.shape[0] < addition_size:
             return None
...
@@ -163,9 +163,9 @@ class LlamaDecoderLayer(nn.Module):
         if rope_scaling is not None and getattr(
             config, "original_max_position_embeddings", None
         ):
-            rope_scaling[
-                "original_max_position_embeddings"
-            ] = config.original_max_position_embeddings
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
...
@@ -313,7 +313,10 @@ class Qwen2ForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-                if self.config.tie_word_embeddings and name=="model.embed_tokens.weight":
+                if (
+                    self.config.tie_word_embeddings
+                    and name == "model.embed_tokens.weight"
+                ):
                     weight_loader(params_dict["lm_head.weight"], loaded_weight)
...
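With the config in place, contributors enable the hooks once per clone and can reproduce a tree-wide reformat like the one in this commit. These are standard pre-commit commands, not anything shown in the diff above:

    pip install pre-commit
    pre-commit install          # run the hooks automatically on every git commit
    pre-commit run --all-files  # one-off pass over the whole repository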