Commit 118f1fc7 authored by maxiao1

sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct

## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 64
python3 bench_sglang.py --num-questions 32 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --backend vllm --num-questions 64
```
### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
python3 bench_other.py --backend lmql --num-questions 32 --parallel 1
```
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
number = 5
def expand_tip(topic, tip, generate):
s = (
"""Please expand a tip for a topic into a detailed paragraph.
Topic: staying healthy
Tip: Regular Exercise
Paragraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement.
Topic: building a campfire
Tip: Choose the Right Location
Paragraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches.
Topic: writing a blog post
Tip: structure your content effectively
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.
Topic: """
+ topic
+ "\nTip: "
+ tip
+ "\nParagraph:"
)
return generate(s, max_tokens=128, stop=["\n\n"])
def suggest_tips(topic, generate):
s = "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
s += "USER: Give some tips for " + topic + ".\n"
s += (
"ASSISTANT: Okay. Here are "
+ str(number)
+ " concise tips, each under 8 words:\n"
)
tips = []
for i in range(1, 1 + number):
s += f"{i}."
tip = generate(s, max_tokens=24, stop=[".", "\n"])
s += tip + ".\n"
tips.append(tip)
paragraphs = [expand_tip(topic, tip, generate=generate) for tip in tips]
for i in range(1, 1 + number):
s += f"Tip {i}:" + paragraphs[i - 1] + "\n"
return s
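# A minimal smoke test for the prompt chaining above, using a hypothetical
# stub backend in place of a real server (illustrative only):
def _demo_with_stub_backend():
    def stub_generate(prompt, max_tokens=None, stop=None):
        # Canned completion standing in for a model response.
        return "stay hydrated"

    return suggest_tips("staying healthy", stub_generate)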
def main(args):
lines = read_jsonl(args.data_path)[: args.num_questions]
states = [None] * len(lines)
# Select backend
call_generate = partial(get_call_generate(args), temperature=0)
# Run requests
tic = time.perf_counter()
if args.backend != "lmql":
def get_one_answer(i):
states[i] = suggest_tips(lines[i]["topic"], call_generate)
if args.parallel == 1:
for i in tqdm(range(len(lines))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(lines)))),
total=len(lines),
)
)
else:
import asyncio
from lmql_funcs import suggest_tips_async
async def get_one_answer_async(i):
states[i] = await suggest_tips_async(lines[i]["topic"], call_generate)
batches = []
for i in range(0, len(lines), args.parallel):
batches.append(list(range(i, min(i + args.parallel, len(lines)))))
loop = asyncio.get_event_loop()
for batch in tqdm(batches):
loop.run_until_complete(
asyncio.gather(*[get_one_answer_async(i) for i in batch])
)
latency = time.perf_counter() - tic
# Print latency
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "tip_suggestion",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="topic.jsonl")
parser.add_argument("--num-questions", type=int, default=100)
args = add_common_other_args_and_parse(parser)
main(args)
import argparse
import json
import time
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
number = 5
@sgl.function
def expand_tip(s, topic, tip):
s += (
"""Please expand a tip for a topic into a detailed paragraph.
Topic: staying healthy
Tip: Regular Exercise
Paragraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement.
Topic: building a campfire
Tip: Choose the Right Location
Paragraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches.
Topic: writing a blog post
Tip: structure your content effectively
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.
Topic: """
+ topic
+ "\nTip: "
+ tip
+ "\nParagraph:"
)
s += sgl.gen("paragraph", max_tokens=128, stop=["\n\n"], temperature=0)
@sgl.function
def suggest_tips(s, topic):
s += "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
s += "USER: Give some tips for " + topic + ".\n"
s += (
"ASSISTANT: Okay. Here are "
+ str(number)
+ " concise tips, each under 8 words:\n"
)
paragraphs = []
for i in range(1, 1 + number):
s += f"{i}." + sgl.gen(f"tip_{i}", max_tokens=24, stop=[".", "\n"]) + ".\n"
paragraphs.append(expand_tip(topic=topic, tip=s[f"tip_{i}"]))
for i in range(1, 1 + number):
s += f"Tip {i}:" + paragraphs[i - 1]["paragraph"] + "\n"
def main(args):
lines = read_jsonl(args.data_path)[: args.num_questions]
arguments = [{"topic": l["topic"]} for l in lines]
# Select backend
sgl.set_default_backend(select_sglang_backend(args))
# Run requests
tic = time.perf_counter()
states = suggest_tips.run_batch(
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
)
latency = time.perf_counter() - tic
# Print latency
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "tip_suggestion",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="topic.jsonl")
parser.add_argument("--num-questions", type=int, default=100)
args = add_common_sglang_args_and_parse(parser)
main(args)
number = 5
async def expand_tip_async(topic, tip, generate):
s = (
"""Please expand a tip for a topic into a detailed paragraph.
Topic: staying healthy
Tip: Regular Exercise
Paragraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement.
Topic: building a campfire
Tip: Choose the Right Location
Paragraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches.
Topic: writing a blog post
Tip: structure your content effectively
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.
Topic: """
+ topic
+ "\nTip: "
+ tip
+ "\nParagraph:"
)
return await generate(s, max_tokens=128, stop="\n\n")
async def suggest_tips_async(topic, generate):
s = "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
s += "USER: Give some tips for " + topic + ".\n"
s += (
"ASSISTANT: Okay. Here are "
+ str(number)
+ " concise tips, each under 8 words:\n"
)
tips = []
for i in range(1, 1 + number):
s += f"{i}."
# NOTE: the stop string differs here because lmql does not support a list of stop tokens
tip = await generate(s, max_tokens=24, stop=".\n")
s += tip + ".\n"
tips.append(tip)
paragraphs = [await expand_tip_async(topic, tip, generate=generate) for tip in tips]
for i in range(1, 1 + number):
s += f"Tip {i}:" + paragraphs[i - 1] + "\n"
return s
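# Smoke test with a hypothetical async stub backend (illustrative only):
async def _demo_with_stub_backend():
    async def stub_generate(prompt, max_tokens=None, stop=None):
        # Canned completion standing in for a model response.
        return "stay hydrated"

    return await suggest_tips_async("staying healthy", stub_generate)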
{"topic": "organizing a successful charity event", "number": 6}
{"topic": "improving personal credit scores", "number": 7}
{"topic": "staying motivated during job searches", "number": 5}
{"topic": "maintaining a work-life balance", "number": 9}
{"topic": "reducing carbon footprint at home", "number": 8}
{"topic": "starting a book club", "number": 5}
{"topic": "learning to play a musical instrument", "number": 7}
{"topic": "getting into freelance writing", "number": 6}
{"topic": "beginner yoga poses", "number": 8}
{"topic": "preparing for graduate school exams", "number": 5}
{"topic": "exploring minimalist living", "number": 9}
{"topic": "effective grocery shopping", "number": 7}
{"topic": "winter camping", "number": 5}
{"topic": "starting a podcast on a budget", "number": 8}
{"topic": "creating a capsule wardrobe", "number": 6}
{"topic": "improving your writing skills", "number": 7}
{"topic": "learning a new software quickly", "number": 9}
{"topic": "reducing anxiety before public speaking", "number": 5}
{"topic": "planning a solo travel adventure", "number": 8}
{"topic": "beginner skateboarders", "number": 6}
{"topic": "studying abroad", "number": 7}
{"topic": "planting a vegetable garden", "number": 5}
{"topic": "adopting a shelter pet", "number": 9}
{"topic": "learning to cook ethnic cuisines", "number": 8}
{"topic": "effective conflict resolution", "number": 5}
{"topic": "starting a vlog", "number": 7}
{"topic": "keeping a daily journal", "number": 6}
{"topic": "improving sleep hygiene", "number": 8}
{"topic": "beginner mountain climbers", "number": 5}
{"topic": "creating a mobile app", "number": 9}
{"topic": "maintaining a saltwater aquarium", "number": 7}
{"topic": "preparing for a baby's arrival", "number": 6}
{"topic": "writing a fantasy novel", "number": 5}
{"topic": "effective team leadership", "number": 8}
{"topic": "making a documentary film", "number": 9}
{"topic": "learning about historical events", "number": 7}
{"topic": "baking gluten-free treats", "number": 6}
{"topic": "improving mental arithmetic skills", "number": 5}
{"topic": "building a treehouse", "number": 8}
{"topic": "getting started with watercolor painting", "number": 9}
{"topic": "creating a YouTube tutorial series", "number": 7}
{"topic": "landscape photography", "number": 5}
{"topic": "navigating cultural differences", "number": 6}
{"topic": "preparing for a marathon", "number": 8}
{"topic": "building an online business", "number": 9}
{"topic": "learning to dance at home", "number": 5}
{"topic": "self-publishing a book", "number": 7}
{"topic": "starting an urban farm", "number": 6}
{"topic": "improving your memory", "number": 8}
{"topic": "creating a personal brand online", "number": 9}
## Download data
```
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```
## Run benchmark
NOTE: This implementation is for throughput/latency benchmarking only. The prompts are not tuned to achieve good accuracy on the GSM-8K task.
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 32
python3 bench_sglang.py --num-questions 16 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 32 --backend vllm
```
### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```
```
python3 bench_other.py --num-questions 32 --backend lightllm
```
### Benchmark guidance
```
python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
python3 bench_other.py --num-questions 8 --backend lmql --parallel 1
```
import argparse
import ast
import json
import re
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def most_frequent_number(numbers):
if not numbers:
return None
frequency = Counter(numbers)
most_frequent = max(frequency, key=frequency.get)
return most_frequent
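# Majority-vote example: most_frequent_number([72, 72, 68]) -> 72. Answers that
# fail to parse contribute INVALID votes like any other value.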
USER_PREFIX = "[INST] "
USER_SUFFIX = " [/INST]"
ASSISTANT_PREFIX = ""
ASSISTANT_SUFFIX = " </s><s>"
# Use a low temperature to make the results more deterministic and the comparison fairer.
temp = 0.001
def propose_plan(s, question, num_branches, call_generate):
s += (
USER_PREFIX
+ """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
+ question
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def execute_plan(s, num_branches, call_generate):
s += (
USER_PREFIX
+ """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def reflect_solution(s, num_branches, call_generate):
s += (
USER_PREFIX
+ """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def get_final_answer(s, num_branches, call_generate):
s += (
USER_PREFIX
+ """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def tree_search(question, num_branches, call_generate):
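    # Four chained stages, branching at each one: with num_branches=2 this
    # yields 2 plans -> 4 solutions -> 8 reflections -> 8 final-answer lists
    # of 2 completions each, i.e. 16 completions per question.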
plan_forks = propose_plan("", question, num_branches, call_generate)
sol_states = []
for plan in plan_forks:
forks = execute_plan(plan, num_branches, call_generate)
sol_states.extend(forks)
ref_states = []
for sol in sol_states:
forks = reflect_solution(sol, num_branches, call_generate)
ref_states.extend(forks)
solutions = []
for sol in ref_states:
ans = get_final_answer(sol, num_branches, call_generate)
solutions.append(ans)
return solutions
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
num_branches = 2
questions = []
labels = []
for i in range(len(lines[: args.num_questions])):
questions.append(lines[i]["question"])
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q, "num_branches": num_branches} for q in questions]
# Select backend
call_generate = get_call_generate(args)
# Run requests
states = [None] * len(questions)
tic = time.perf_counter()
if args.backend != "lmql":
def get_one_answer(i):
states[i] = tree_search(**arguments[i], call_generate=call_generate)
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
else:
import asyncio
from lmql_funcs import tree_search_async
async def get_one_answer_async(i):
states[i] = await tree_search_async(
**arguments[i], call_generate=call_generate
)
batches = [
[] for _ in range((len(questions) + args.parallel - 1) // args.parallel)
]
for i in range(len(questions)):
batches[i // args.parallel].append(i)
loop = asyncio.get_event_loop()
for bt in tqdm(batches):
tasks = [get_one_answer_async(k) for k in bt]
loop.run_until_complete(asyncio.gather(*tasks))
latency = time.perf_counter() - tic
answers_text = []
for s in states:
answers_text.append([x for xs in s for x in xs])
preds = []
for i in range(len(states)):
answers = [get_answer_value(v) for v in answers_text[i]]
preds.append(most_frequent_number(answers))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)
with open(args.result_file, "a") as fout:
value = {
"task": "tree_of_thought_gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser)
main(args)
import argparse
import ast
import json
import re
import time
from collections import Counter
import numpy as np
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def most_frequent_number(numbers):
if not numbers:
return None
frequency = Counter(numbers)
most_frequent = max(frequency, key=frequency.get)
return most_frequent
# Use a low temperature to make the results more deterministic and the comparison fairer.
temp = 0.001
def propose_plan(s, question, num_branches):
s += sgl.user(
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
+ question
)
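    # fork() duplicates the current prompt state into num_branches parallel
    # streams, so each branch generates an independent plan.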
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
return forks
def execute_plan(s, num_branches):
s += sgl.user(
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
)
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
return forks
def reflect_solution(s, num_branches):
s += sgl.user(
"""Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
)
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
return forks
def get_final_answer(s, num_branches):
s += sgl.user(
"""Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
)
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("final_answer", max_tokens=256, temperature=temp))
return forks
@sgl.function
def tree_search(s, question, num_branches):
plan_forks = propose_plan(s, question, num_branches)
sol_states = []
for plan in plan_forks:
forks = execute_plan(plan, num_branches)
sol_states.extend(forks)
ref_states = []
for sol in sol_states:
forks = reflect_solution(sol, num_branches)
ref_states.extend(forks)
solutions = []
for sol in ref_states:
forks = get_final_answer(sol, num_branches)
solutions.append(forks)
solutions = [[s.text() for s in forks] for forks in solutions]
return solutions
def main(args):
lines = read_jsonl(args.data_path)
lines = list(lines)
# Construct prompts
num_branches = 2
questions = []
labels = []
for i in range(len(lines[: args.num_questions])):
questions.append(lines[i]["question"])
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q, "num_branches": num_branches} for q in questions]
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.perf_counter()
states = tree_search.run_batch(
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
answers_text = []
for s in states:
answers_text.append([x for xs in s.ret_value for x in xs])
preds = []
for i in range(len(states)):
answers = [get_answer_value(v) for v in answers_text[i]]
preds.append(most_frequent_number(answers))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)
with open(args.result_file, "a") as fout:
value = {
"task": "tree_of_thought_gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)
main(args)
from bench_other import (
ASSISTANT_PREFIX,
ASSISTANT_SUFFIX,
USER_PREFIX,
USER_SUFFIX,
temp,
)
async def propose_plan_async(s, question, num_branches, call_generate):
s += (
USER_PREFIX
+ """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
+ question
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = await call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
async def execute_plan_async(s, num_branches, call_generate):
s += (
USER_PREFIX
+ """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = await call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
async def reflect_solution_async(s, num_branches, call_generate):
s += (
USER_PREFIX
+ """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = await call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
async def get_final_answer_async(s, num_branches, call_generate):
s += (
USER_PREFIX
+ """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = await call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
async def tree_search_async(question, num_branches, call_generate):
plan_forks = await propose_plan_async("", question, num_branches, call_generate)
sol_states = []
for plan in plan_forks:
forks = await execute_plan_async(plan, num_branches, call_generate)
sol_states.extend(forks)
ref_states = []
for sol in sol_states:
forks = await reflect_solution_async(sol, num_branches, call_generate)
ref_states.extend(forks)
solutions = []
for sol in ref_states:
ans = await get_final_answer_async(sol, num_branches, call_generate)
solutions.append(ans)
return solutions
## Download data
```
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 32 --parallel 16
python3 bench_sglang.py --num-questions 10 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 32 --backend vllm
```
### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```
```
python3 bench_other.py --num-questions 32 --backend lightllm
```
### Benchmark guidance
```
python3 bench_other.py --num-questions 32 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
import argparse
import ast
import json
import re
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def most_frequent_number(numbers):
if not numbers:
return None
frequency = Counter(numbers)
most_frequent = max(frequency, key=frequency.get)
return most_frequent
USER_PREFIX = "[INST] "
USER_SUFFIX = " [/INST]"
ASSISTANT_PREFIX = ""
ASSISTANT_SUFFIX = " </s><s>"
# Use a low temperature to make the results more deterministic and the comparison fairer.
temp = 0.3
def propose_plan(s, question, num_branches, call_generate):
s += (
USER_PREFIX
+ """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
+ question
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def execute_plan(s, num_branches, call_generate):
s += (
USER_PREFIX
+ """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def reflect_solution(s, num_branches, call_generate):
s += (
USER_PREFIX
+ """Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def tree_search(question, num_branches, call_generate):
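    # Plans -> solutions -> reflections; the reflection scores are generated to
    # exercise the backend, but only the solution texts are collected for the
    # majority vote.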
s = ""
solutions = []
plan_forks = propose_plan(s, question, num_branches, call_generate)
for plan in plan_forks:
sol_forks = execute_plan(plan, num_branches, call_generate)
for sol in sol_forks:
score_forks = reflect_solution(sol, num_branches, call_generate)
solutions.append(sol_forks)
return solutions
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
num_branches = 3
questions = []
labels = []
for i in range(len(lines[: args.num_questions])):
questions.append(lines[i]["question"])
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q, "num_branches": num_branches} for q in questions]
# Select backend
call_generate = get_call_generate(args)
# Run requests
states = [None] * len(questions)
def get_one_answer(i):
states[i] = tree_search(**arguments[i], call_generate=call_generate)
tic = time.perf_counter()
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
latency = time.perf_counter() - tic
answers_text = []
for s in states:
answers_text.append([x for xs in s for x in xs])
preds = []
for i in range(len(states)):
answers = [get_answer_value(v) for v in answers_text[i]]
preds.append(most_frequent_number(answers))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)
with open(args.result_file, "a") as fout:
value = {
"task": "tree_of_thought_gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser)
main(args)
import argparse
import ast
import json
import re
import time
from collections import Counter
import numpy as np
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def most_frequent_number(numbers):
if not numbers:
return None
frequency = Counter(numbers)
most_frequent = max(frequency, key=frequency.get)
return most_frequent
# Use a low temperature to make the results more deterministic and the comparison fairer.
temp = 0.3
def propose_plan(s, question, num_branches):
s += sgl.user(
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
+ question
)
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
return forks
def execute_plan(s, num_branches):
s += sgl.user(
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
)
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
return forks
def reflect_solution(s, num_branches):
s += sgl.user(
"""Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
)
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
return forks
@sgl.function
def tree_search(s, question, num_branches):
forks_to_join = []
plan_forks = propose_plan(s, question, num_branches)
forks_to_join.append(plan_forks)
sol_states = []
for plan in plan_forks:
forks = execute_plan(plan, num_branches)
forks_to_join.append(forks)
sol_states.extend(forks)
for sol in sol_states:
forks = reflect_solution(sol, num_branches)
forks_to_join.append(forks)
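    # Wait on every forked stream (deepest forks first) so all pending
    # generations complete before the batch state is returned.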
for f in reversed(forks_to_join):
f.join()
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
num_branches = 3
questions = []
labels = []
for i in range(len(lines[: args.num_questions])):
questions.append(lines[i]["question"])
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q, "num_branches": num_branches} for q in questions]
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.perf_counter()
states = tree_search.run_batch(
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
answers_text = []
for s in states:
answers_text.append([x for xs in s["answer"] for x in xs])
preds = []
for i in range(len(states)):
answers = [get_answer_value(v) for v in answers_text[i]]
preds.append(most_frequent_number(answers))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)
with open(args.result_file, "a") as fout:
value = {
"task": "tree_of_thought_gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)
main(args)
ARG CUDA_VERSION=12.9.1
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 AS base
ARG BUILD_TYPE=all
ARG BRANCH_TYPE=remote
ARG DEEPEP_COMMIT=b92d0d4860ce6866cd6d31bfbae937f9a7a3772b
ARG CMAKE_BUILD_PARALLEL_LEVEL=2
ENV DEBIAN_FRONTEND=noninteractive \
CUDA_HOME=/usr/local/cuda \
GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/ \
NVSHMEM_DIR=/sgl-workspace/nvshmem/install
# Add GKE default lib and bin locations.
ENV PATH="${PATH}:/usr/local/nvidia/bin" \
LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/nvidia/lib:/usr/local/nvidia/lib64"
RUN apt update && apt install wget -y && apt install software-properties-common -y \
&& add-apt-repository ppa:deadsnakes/ppa -y \
&& apt install python3.12-full python3.12-dev python3.10-venv -y \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 2 \
&& update-alternatives --set python3 /usr/bin/python3.12 \
&& wget https://bootstrap.pypa.io/get-pip.py \
&& python3 get-pip.py
# Set timezone and install all packages
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update && apt-get install -y --no-install-recommends \
tzdata \
software-properties-common netcat-openbsd kmod unzip openssh-server \
curl wget lsof zsh ccache tmux htop git-lfs tree \
build-essential cmake \
libopenmpi-dev libnuma1 libnuma-dev \
libibverbs-dev libibverbs1 libibumad3 \
librdmacm1 libnl-3-200 libnl-route-3-200 libnl-route-3-dev libnl-3-dev \
ibverbs-providers infiniband-diags perftest \
libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev \
libboost-all-dev libssl-dev \
libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler protobuf-compiler-grpc \
pybind11-dev \
libhiredis-dev libcurl4-openssl-dev \
libczmq4 libczmq-dev \
libfabric-dev \
patchelf \
nvidia-dkms-550 \
devscripts debhelper fakeroot dkms check libsubunit0 libsubunit-dev \
&& ln -sf /usr/bin/python3.12 /usr/bin/python \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
# GDRCopy installation
RUN mkdir -p /tmp/gdrcopy && cd /tmp \
&& git clone https://github.com/NVIDIA/gdrcopy.git -b v2.4.4 \
&& cd gdrcopy/packages \
&& CUDA=/usr/local/cuda ./build-deb-packages.sh \
&& dpkg -i gdrdrv-dkms_*.deb libgdrapi_*.deb gdrcopy-tests_*.deb gdrcopy_*.deb \
&& cd / && rm -rf /tmp/gdrcopy
# Fix DeepEP IBGDA symlink
RUN ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so
FROM scratch AS local_src
COPY . /src
FROM base AS build-image
# Install SGLang
WORKDIR /sgl-workspace
ARG BRANCH_TYPE
COPY --from=local_src /src /tmp/local_src
RUN if [ "$BRANCH_TYPE" = "local" ]; then \
cp -r /tmp/local_src /sgl-workspace/sglang; \
else \
git clone --depth=1 https://github.com/sgl-project/sglang.git /sgl-workspace/sglang; \
fi \
&& rm -rf /tmp/local_src
RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel html5lib six \
&& cd sglang \
&& case "$CUDA_VERSION" in \
12.6.1) CUINDEX=126 ;; \
12.8.1) CUINDEX=128 ;; \
12.9.1) CUINDEX=129 ;; \
*) echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1 ;; \
esac \
&& python3 -m pip install --no-cache-dir -e "python[${BUILD_TYPE}]" --extra-index-url https://download.pytorch.org/whl/cu${CUINDEX} \
&& python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.6 --force-reinstall --no-deps \
&& python3 -m flashinfer --download-cubin \
&& if [ "$CUDA_VERSION" = "12.6.1" ]; then \
python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.3.9.post2/sgl_kernel-0.3.9.post2+cu124-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \
fi
# Download source files
RUN wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.3.9/source/nvshmem_src_cuda12-all-all-3.3.9.tar.gz && \
git clone https://github.com/deepseek-ai/DeepEP.git && \
cd DeepEP && git checkout ${DEEPEP_COMMIT} && sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh && \
cd .. && \
tar -xf nvshmem_src_cuda12-all-all-3.3.9.tar.gz && \
mv nvshmem_src nvshmem && \
rm -f /sgl-workspace/nvshmem_src_cuda12-all-all-3.3.9.tar.gz
# Build and install NVSHMEM
RUN cd /sgl-workspace/nvshmem && \
NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \
NVSHMEM_MPI_SUPPORT=0 \
NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=${NVSHMEM_DIR} -DCMAKE_CUDA_ARCHITECTURES="90" && \
cmake --build build --target install -j${CMAKE_BUILD_PARALLEL_LEVEL}
# Install DeepEP
RUN cd /sgl-workspace/DeepEP && \
case "$CUDA_VERSION" in \
12.6.1) \
CHOSEN_TORCH_CUDA_ARCH_LIST='9.0' \
;; \
12.8.1|12.9.1) \
CHOSEN_TORCH_CUDA_ARCH_LIST='9.0;10.0' \
;; \
*) \
echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1 \
;; \
esac && \
NVSHMEM_DIR=${NVSHMEM_DIR} TORCH_CUDA_ARCH_LIST="${CHOSEN_TORCH_CUDA_ARCH_LIST}" pip install .
# Python tools
RUN python3 -m pip install --no-cache-dir \
datamodel_code_generator \
mooncake-transfer-engine==0.3.5 \
pre-commit \
pytest \
black \
isort \
icdiff \
uv \
wheel \
scikit-build-core \
nixl \
py-spy
# Install development tools and utilities
RUN apt-get update && apt-get install -y \
gdb \
ninja-build \
vim \
tmux \
htop \
wget \
curl \
locales \
lsof \
git \
git-lfs \
zsh \
tree \
silversearcher-ag \
cloc \
unzip \
pkg-config \
libssl-dev \
bear \
ccache \
less \
&& apt install -y rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN apt update -y \
&& apt install -y --no-install-recommends gnupg \
&& echo "deb http://developer.download.nvidia.com/devtools/repos/ubuntu2004/amd64 /" | tee /etc/apt/sources.list.d/nvidia-devtools.list \
&& apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub \
&& apt update -y \
&& apt install nsight-systems-cli -y
# Set up locale
RUN locale-gen en_US.UTF-8
ENV LANG=en_US.UTF-8
ENV LANGUAGE=en_US:en
ENV LC_ALL=en_US.UTF-8
# Install minimal Python packages
RUN python3 -m pip install --no-cache-dir --break-system-packages \
pytest \
black \
isort \
icdiff \
scikit_build_core \
uv \
pre-commit \
pandas \
matplotlib \
tabulate
# Install diff-so-fancy
RUN curl -LSso /usr/local/bin/diff-so-fancy https://github.com/so-fancy/diff-so-fancy/releases/download/v1.4.4/diff-so-fancy \
&& chmod +x /usr/local/bin/diff-so-fancy
# Install clang-format
RUN curl -LSso /usr/local/bin/clang-format https://github.com/muttleyxd/clang-tools-static-binaries/releases/download/master-32d3ac78/clang-format-16_linux-amd64 \
&& chmod +x /usr/local/bin/clang-format
# Install clangd
RUN curl -L https://github.com/clangd/clangd/releases/download/18.1.3/clangd-linux-18.1.3.zip -o clangd.zip \
&& unzip clangd.zip \
&& cp -r clangd_18.1.3/bin/* /usr/local/bin/ \
&& cp -r clangd_18.1.3/lib/* /usr/local/lib/ \
&& rm -rf clangd_18.1.3 clangd.zip
# Install CMake
RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1-linux-x86_64.tar.gz \
&& tar -xzf cmake-3.31.1-linux-x86_64.tar.gz \
&& cp -r cmake-3.31.1-linux-x86_64/bin/* /usr/local/bin/ \
&& cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \
&& rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz
# Install Rust toolchain for sgl-router
ENV PATH="/root/.cargo/bin:${PATH}"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
&& rustc --version && cargo --version
# Build and install sgl-router
RUN python3 -m pip install --no-cache-dir setuptools-rust \
&& cd /sgl-workspace/sglang/sgl-router \
&& cargo build --release \
&& python3 -m pip install --no-cache-dir . \
&& rm -rf /root/.cache
# Add yank script
COPY --chown=root:root <<-"EOF" /usr/local/bin/yank
#!/bin/bash
put() {
esc=$1
test -n "$TMUX" -o -z "${TERM##screen*}" && esc="\033Ptmux;\033$esc\033\\"
printf "$esc"
}
put "\033]52;c;!\a"
buf=$( cat "$@" )
len=$( printf %s "$buf" | wc -c ) max=74994
test $len -gt $max && echo "$0: input is $(( len - max )) bytes too long" >&2
put "\033]52;c;$( printf %s "$buf" | head -c $max | base64 | tr -d '\r\n' )\a"
test -n "$TMUX" && tmux set-buffer "$buf" ||:
EOF
RUN chmod +x /usr/local/bin/yank
# Install oh-my-zsh and plugins
RUN sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" "" --unattended \
&& git clone https://github.com/zsh-users/zsh-autosuggestions ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-autosuggestions \
&& git clone https://github.com/zsh-users/zsh-syntax-highlighting.git ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-syntax-highlighting
# Configure Vim
COPY --chown=root:root <<-"EOF" /root/.vimrc
function! Yank(text) abort
let escape = system('yank', a:text)
if v:shell_error
echoerr escape
else
call writefile([escape], '/dev/tty', 'b')
endif
endfunction
noremap <silent> <Leader>y y:<C-U>call Yank(@0)<CR>
" automatically run yank(1) whenever yanking in Vim
function! CopyYank() abort
call Yank(join(v:event.regcontents, "\n"))
endfunction
autocmd TextYankPost * call CopyYank()
" Basic settings
set number
syntax on
set mouse=a
filetype indent on
" Indentation
set autoindent nosmartindent
set smarttab
set expandtab
set shiftwidth=4
set softtabstop=4
" Visual guides
set colorcolumn=120
highlight ColorColumn ctermbg=5
" Status line
set laststatus=2
set statusline=%<%f\ %h%m%r%=%{\"[\".(&fenc==\"\"?&enc:&fenc).((exists(\"+bomb\")\ &&\ &bomb)?\",B\":\"\").\"]\ \"}%k\ %-14.(%l,%c%V%)\ %P
" Backspace behavior
set backspace=2
" Encoding
set encoding=utf-8
set fileencoding=utf-8
EOF
# Configure tmux
COPY --chown=root:root <<-"EOF" /root/.tmux.conf
# Pane border styling
set -g pane-border-style fg='#742727',bg=black
set -g pane-active-border-style fg=red,bg=black
# Status bar styling
set -g status-style bg='#0C8A92',fg=black
# Change prefix key to backtick
set-option -g prefix `
unbind C-b
bind-key ` send-prefix
# Split panes using - and = with current path
unbind '"'
bind - splitw -v -c '#{pane_current_path}'
unbind '%'
bind = splitw -h -c '#{pane_current_path}'
# Vi mode settings
bind-key -T copy-mode-vi Y send-keys -X copy-pipe 'yank > #{pane_tty}'
set-window-option -g mode-keys vi
# Other settings
set-option -g escape-time 0
set-option -g base-index 1
set-window-option -g mouse on
set -g history-limit 100000
EOF
# Configure Git
RUN git config --global core.editor "vim" \
&& git config --global core.whitespace "fix,-indent-with-non-tab,trailing-space,cr-at-eol" \
&& git config --global core.pager "diff-so-fancy | less --tabs=4 -RFX" \
&& git config --global color.ui true \
&& git config --global color."diff-highlight".oldNormal "red bold" \
&& git config --global color."diff-highlight".oldHighlight "red bold 52" \
&& git config --global color."diff-highlight".newNormal "green bold" \
&& git config --global color."diff-highlight".newHighlight "green bold 22" \
&& git config --global color.diff.meta "11" \
&& git config --global color.diff.frag "magenta bold" \
&& git config --global color.diff.commit "yellow bold" \
&& git config --global color.diff.old "red bold" \
&& git config --global color.diff.new "green bold" \
&& git config --global color.diff.whitespace "red reverse" \
&& git config --global alias.lg "log --color --graph --pretty=format:'%Cred%h%Creset - %s %Cgreen(%cr) %C(bold blue)<%an>%Creset%C(auto)%d%Creset' --abbrev-commit --" \
&& git config --global http.sslVerify false \
&& git config --global pull.rebase true
# Configure zsh
COPY --chown=root:root <<-"EOF" /root/.zshrc
export ZSH="/root/.oh-my-zsh"
# Theme
ZSH_THEME="robbyrussell"
# Plugins
plugins=(
git
z
zsh-autosuggestions
zsh-syntax-highlighting
)
source $ZSH/oh-my-zsh.sh
# Aliases
alias ll='ls -alF'
alias la='ls -A'
alias l='ls -CF'
alias vi='vim'
# Enhanced history
HISTSIZE=10000
SAVEHIST=10000
setopt HIST_IGNORE_ALL_DUPS
setopt HIST_FIND_NO_DUPS
setopt INC_APPEND_HISTORY
EOF
RUN set -eux; \
curl --proto '=https' --tlsv1.2 -sSf https://just.systems/install.sh | bash -s -- --to /usr/local/bin
# Set workspace directory
WORKDIR /sgl-workspace/sglang
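# Example build (tag is illustrative; BRANCH_TYPE=local builds from this
# checkout instead of cloning the upstream repository):
#   docker build --build-arg CUDA_VERSION=12.9.1 --build-arg BRANCH_TYPE=local -t sglang:dev .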
ARG CUDA_VERSION=12.9.1
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04
ARG BUILD_TYPE=blackwell
ARG DEEPEP_COMMIT=1b14ad661c7640137fcfe93cccb2694ede1220b0
ARG CMAKE_BUILD_PARALLEL_LEVEL=2
ARG SGL_KERNEL_VERSION=0.3.9.post2
ENV DEBIAN_FRONTEND=noninteractive \
CUDA_HOME=/usr/local/cuda \
GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/ \
NVSHMEM_DIR=/sgl-workspace/nvshmem/install \
BUILD_TYPE=${BUILD_TYPE} \
TORCH_CUDA_ARCH_LIST="10.0 12.0"
# Set timezone and install all packages
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update && apt-get install -y --no-install-recommends \
tzdata \
software-properties-common netcat-openbsd kmod unzip openssh-server \
curl wget lsof zsh ccache tmux htop git-lfs tree \
python3 python3-pip python3-dev libpython3-dev python3-venv \
build-essential cmake \
libopenmpi-dev libnuma1 libnuma-dev \
libibverbs-dev libibverbs1 libibumad3 \
librdmacm1 libnl-3-200 libnl-route-3-200 libnl-route-3-dev libnl-3-dev \
ibverbs-providers infiniband-diags perftest \
libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev \
libboost-all-dev libssl-dev \
libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler-grpc \
pybind11-dev \
libhiredis-dev libcurl4-openssl-dev \
libczmq4 libczmq-dev \
libfabric-dev \
patchelf \
nvidia-dkms-550 \
devscripts debhelper fakeroot dkms check libsubunit0 libsubunit-dev \
&& ln -sf /usr/bin/python3 /usr/bin/python \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
# Install packages that the blackwell build type is missing
RUN python3 -m pip install openai httpx
# GDRCopy installation
RUN mkdir -p /tmp/gdrcopy && cd /tmp \
&& git clone https://github.com/NVIDIA/gdrcopy.git -b v2.4.4 \
&& cd gdrcopy/packages \
&& CUDA=/usr/local/cuda ./build-deb-packages.sh \
&& dpkg -i gdrdrv-dkms_*.deb libgdrapi_*.deb gdrcopy-tests_*.deb gdrcopy_*.deb \
&& cd / && rm -rf /tmp/gdrcopy
# Fix DeepEP IBGDA symlink
RUN ln -sf /usr/lib/$(uname -m)-linux-gnu/libmlx5.so.1 /usr/lib/$(uname -m)-linux-gnu/libmlx5.so
# Clone and install SGLang
WORKDIR /sgl-workspace
RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel html5lib six \
&& git clone --depth 1 https://github.com/sgl-project/sglang.git \
&& cd sglang \
&& case "$CUDA_VERSION" in \
12.9.1) CUINDEX=129 ;; \
*) echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1 ;; \
esac \
&& if [ "$CUDA_VERSION" = "12.9.1" ]; then \
python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.6 --force-reinstall --no-deps ; \
python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sgl_kernel-${SGL_KERNEL_VERSION}+cu129-cp310-abi3-manylinux2014_$(uname -m).whl --force-reinstall --no-deps ; \
fi \
&& python3 -m pip install --no-cache-dir -e "python[${BUILD_TYPE}]" --extra-index-url https://download.pytorch.org/whl/cu${CUINDEX} \
&& python3 -m flashinfer --download-cubin
# Download source files
RUN wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.3.9/source/nvshmem_src_cuda12-all-all-3.3.9.tar.gz && \
git clone https://github.com/fzyzcjy/DeepEP.git && \
cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd .. && \
tar -xf nvshmem_src_cuda12-all-all-3.3.9.tar.gz && \
mv nvshmem_src nvshmem && \
rm -f /sgl-workspace/nvshmem_src_cuda12-all-all-3.3.9.tar.gz
# Build and install NVSHMEM
RUN cd /sgl-workspace/nvshmem && \
NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \
NVSHMEM_MPI_SUPPORT=0 \
NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=${NVSHMEM_DIR} -DCMAKE_CUDA_ARCHITECTURES="90;100;120" && \
cmake --build build --target install -j${CMAKE_BUILD_PARALLEL_LEVEL}
# Install DeepEP
RUN cd /sgl-workspace/DeepEP && \
NVSHMEM_DIR=${NVSHMEM_DIR} pip install .
# Python tools
RUN python3 -m pip install --no-cache-dir \
datamodel_code_generator \
mooncake-transfer-engine==0.3.5 \
pre-commit \
pytest \
black \
isort \
icdiff \
uv \
wheel \
scikit-build-core
# Install nixl kv transfer backend
RUN python3 -m pip install --no-cache-dir \
nixl
# Install development tools and utilities
RUN apt-get update && apt-get install -y \
gdb \
ninja-build \
vim \
tmux \
htop \
wget \
curl \
locales \
lsof \
git \
git-lfs \
zsh \
tree \
silversearcher-ag \
cloc \
unzip \
pkg-config \
libssl-dev \
bear \
ccache \
less \
&& apt install -y rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN apt update -y \
&& apt install -y --no-install-recommends gnupg \
&& echo "deb http://developer.download.nvidia.com/devtools/repos/ubuntu2004/$(if [ "$(uname -m)" = "aarch64" ]; then echo "arm64"; else echo "amd64"; fi) /" | tee /etc/apt/sources.list.d/nvidia-devtools.list \
&& apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/$(if [ "$(uname -m)" = "aarch64" ]; then echo "arm64"; else echo "x86_64"; fi)/7fa2af80.pub \
&& apt update -y \
&& apt install nsight-systems-cli -y
# Set up locale
RUN locale-gen en_US.UTF-8
ENV LANG=en_US.UTF-8
ENV LANGUAGE=en_US:en
ENV LC_ALL=en_US.UTF-8
# Install minimal Python packages
RUN python3 -m pip install --no-cache-dir --break-system-packages \
pytest \
black \
isort \
icdiff \
scikit_build_core \
uv \
pre-commit \
pandas \
matplotlib \
tabulate
# Install diff-so-fancy
RUN curl -LSso /usr/local/bin/diff-so-fancy https://github.com/so-fancy/diff-so-fancy/releases/download/v1.4.4/diff-so-fancy \
&& chmod +x /usr/local/bin/diff-so-fancy
# Install clang-format
RUN curl -LSso /usr/local/bin/clang-format https://github.com/muttleyxd/clang-tools-static-binaries/releases/download/master-32d3ac78/clang-format-16_linux-amd64 \
&& chmod +x /usr/local/bin/clang-format
# Install clangd
RUN curl -L https://github.com/clangd/clangd/releases/download/18.1.3/clangd-linux-18.1.3.zip -o clangd.zip \
&& unzip clangd.zip \
&& cp -r clangd_18.1.3/bin/* /usr/local/bin/ \
&& cp -r clangd_18.1.3/lib/* /usr/local/lib/ \
&& rm -rf clangd_18.1.3 clangd.zip
# Install CMake
RUN CMAKE_VERSION=3.31.1 \
&& ARCH=$(uname -m) \
&& CMAKE_INSTALLER="cmake-${CMAKE_VERSION}-linux-${ARCH}" \
&& wget "https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/${CMAKE_INSTALLER}.tar.gz" \
&& tar -xzf "${CMAKE_INSTALLER}.tar.gz" \
&& cp -r "${CMAKE_INSTALLER}/bin/"* /usr/local/bin/ \
&& cp -r "${CMAKE_INSTALLER}/share/"* /usr/local/share/ \
&& rm -rf "${CMAKE_INSTALLER}" "${CMAKE_INSTALLER}.tar.gz"
# Add yank script
COPY --chown=root:root <<-"EOF" /usr/local/bin/yank
#!/bin/bash
put() {
esc=$1
test -n "$TMUX" -o -z "${TERM##screen*}" && esc="\033Ptmux;\033$esc\033\\"
printf "$esc"
}
put "\033]52;c;!\a"
buf=$( cat "$@" )
len=$( printf %s "$buf" | wc -c ) max=74994
test $len -gt $max && echo "$0: input is $(( len - max )) bytes too long" >&2
put "\033]52;c;$( printf %s "$buf" | head -c $max | base64 | tr -d '\r\n' )\a"
test -n "$TMUX" && tmux set-buffer "$buf" ||:
EOF
RUN chmod +x /usr/local/bin/yank
# Install oh-my-zsh and plugins
RUN sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" "" --unattended \
&& git clone https://github.com/zsh-users/zsh-autosuggestions ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-autosuggestions \
&& git clone https://github.com/zsh-users/zsh-syntax-highlighting.git ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-syntax-highlighting
# Configure Vim
COPY --chown=root:root <<-"EOF" /root/.vimrc
function! Yank(text) abort
let escape = system('yank', a:text)
if v:shell_error
echoerr escape
else
call writefile([escape], '/dev/tty', 'b')
endif
endfunction
noremap <silent> <Leader>y y:<C-U>call Yank(@0)<CR>
" automatically run yank(1) whenever yanking in Vim
function! CopyYank() abort
call Yank(join(v:event.regcontents, "\n"))
endfunction
autocmd TextYankPost * call CopyYank()
" Basic settings
set number
syntax on
set mouse=a
filetype indent on
" Indentation
set autoindent nosmartindent
set smarttab
set expandtab
set shiftwidth=4
set softtabstop=4
" Visual guides
set colorcolumn=120
highlight ColorColumn ctermbg=5
" Status line
set laststatus=2
set statusline=%<%f\ %h%m%r%=%{\"[\".(&fenc==\"\"?&enc:&fenc).((exists(\"+bomb\")\ &&\ &bomb)?\",B\":\"\").\"]\ \"}%k\ %-14.(%l,%c%V%)\ %P
" Backspace behavior
set backspace=2
" Encoding
set encoding=utf-8
set fileencoding=utf-8
EOF
# Configure tmux
COPY --chown=root:root <<-"EOF" /root/.tmux.conf
# Pane border styling
set -g pane-border-style fg='#742727',bg=black
set -g pane-active-border-style fg=red,bg=black
# Status bar styling
set -g status-style bg='#0C8A92',fg=black
# Change prefix key to backtick
set-option -g prefix `
unbind C-b
bind-key ` send-prefix
# Split panes using - and = with current path
unbind '"'
bind - splitw -v -c '#{pane_current_path}'
unbind '%'
bind = splitw -h -c '#{pane_current_path}'
# Vi mode settings
bind-key -T copy-mode-vi Y send-keys -X copy-pipe 'yank > #{pane_tty}'
set-window-option -g mode-keys vi
# Other settings
set-option -g escape-time 0
set-option -g base-index 1
set-window-option -g mouse on
EOF
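# Usage note for the config above: enter copy mode (prefix + [), select text
# with vi keys, and press Y to pipe the selection through the yank script
# installed earlier, copying it to the local clipboard.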
# Configure Git
RUN git config --global core.editor "vim" \
&& git config --global core.whitespace "fix,-indent-with-non-tab,trailing-space,cr-at-eol" \
&& git config --global core.pager "diff-so-fancy | less --tabs=4 -RFX" \
&& git config --global color.ui true \
&& git config --global color."diff-highlight".oldNormal "red bold" \
&& git config --global color."diff-highlight".oldHighlight "red bold 52" \
&& git config --global color."diff-highlight".newNormal "green bold" \
&& git config --global color."diff-highlight".newHighlight "green bold 22" \
&& git config --global color.diff.meta "11" \
&& git config --global color.diff.frag "magenta bold" \
&& git config --global color.diff.commit "yellow bold" \
&& git config --global color.diff.old "red bold" \
&& git config --global color.diff.new "green bold" \
&& git config --global color.diff.whitespace "red reverse" \
&& git config --global alias.lg "log --color --graph --pretty=format:'%Cred%h%Creset - %s %Cgreen(%cr) %C(bold blue)<%an>%Creset%C(auto)%d%Creset' --abbrev-commit --" \
&& git config --global http.sslVerify false \
&& git config --global pull.rebase true
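# With the alias above, `git lg` prints a compact colored commit graph, and
# `git diff` output is rendered through diff-so-fancy and less.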
# Configure zsh
COPY --chown=root:root <<-"EOF" /root/.zshrc
export ZSH="/root/.oh-my-zsh"
# Theme
ZSH_THEME="robbyrussell"
# Plugins
plugins=(
git
z
zsh-autosuggestions
zsh-syntax-highlighting
)
source $ZSH/oh-my-zsh.sh
# Aliases
alias ll='ls -alF'
alias la='ls -A'
alias l='ls -CF'
alias vi='vim'
# Enhanced history
HISTSIZE=10000
SAVEHIST=10000
setopt HIST_IGNORE_ALL_DUPS
setopt HIST_FIND_NO_DUPS
setopt INC_APPEND_HISTORY
EOF
RUN set -eux ; \
curl --proto '=https' --tlsv1.2 -sSf https://just.systems/install.sh | bash -s -- --to /usr/local/bin
# Set workspace directory
WORKDIR /sgl-workspace/sglang
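# Example build command for the development image above (a sketch; the file
# name Dockerfile.dev is an assumption):
#   docker build -f Dockerfile.dev -t sglang:dev .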
ARG CANN_VERSION=8.2.rc1
ARG DEVICE_TYPE=a3
ARG OS=ubuntu22.04
ARG PYTHON_VERSION=py3.11
FROM quay.io/ascend/cann:$CANN_VERSION-$DEVICE_TYPE-$OS-$PYTHON_VERSION
# Update pip & apt sources
ARG PIP_INDEX_URL="https://pypi.org/simple/"
ARG APTMIRROR=""
ARG MEMFABRIC_URL=https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/mf_adapter-1.0.0-cp311-cp311-linux_aarch64.whl
ARG PYTORCH_VERSION=2.6.0
ARG TORCHVISION_VERSION=0.21.0
ARG PTA_URL="https://gitee.com/ascend/pytorch/releases/download/v7.1.0.1-pytorch2.6.0/torch_npu-2.6.0.post1-cp311-cp311-manylinux_2_28_aarch64.whl"
ARG VLLM_TAG=v0.8.5
ARG TRITON_ASCEND_URL=https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend-3.2.0.dev20250729-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl
ARG SGLANG_TAG=main
ARG ASCEND_CANN_PATH=/usr/local/Ascend/ascend-toolkit
ARG SGLANG_KERNEL_NPU_TAG=main
WORKDIR /workspace
# Define environments
ENV DEBIAN_FRONTEND=noninteractive
RUN pip config set global.index-url $PIP_INDEX_URL
RUN if [ -n "$APTMIRROR" ];then sed -i "s|.*.ubuntu.com|$APTMIRROR|g" /etc/apt/sources.list ;fi
# Install development tools and utilities
RUN apt-get update -y && apt upgrade -y && apt-get install -y \
build-essential \
cmake \
vim \
wget \
curl \
net-tools \
zlib1g-dev \
lld \
clang \
locales \
ccache \
openssl \
libssl-dev \
pkg-config \
ca-certificates \
protobuf-compiler \
&& rm -rf /var/cache/apt/* \
&& rm -rf /var/lib/apt/lists/* \
&& update-ca-certificates \
&& locale-gen en_US.UTF-8
ENV LANG=en_US.UTF-8
ENV LANGUAGE=en_US:en
ENV LC_ALL=en_US.UTF-8
ENV PATH="/root/.cargo/bin:${PATH}"
# Install dependencies
# TODO: install from pypi released memfabric
RUN pip install $MEMFABRIC_URL --no-cache-dir
RUN pip install setuptools-rust wheel build --no-cache-dir
# install rustup from rustup.rs
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
&& rustc --version && cargo --version && protoc --version
# Install vLLM
RUN git clone --depth 1 https://github.com/vllm-project/vllm.git --branch $VLLM_TAG && \
(cd vllm && VLLM_TARGET_DEVICE="empty" pip install -v . --no-cache-dir) && rm -rf vllm
# TODO: install from pypi released triton-ascend
RUN pip install torch==$PYTORCH_VERSION torchvision==$TORCHVISION_VERSION --index-url https://download.pytorch.org/whl/cpu --no-cache-dir \
&& wget ${PTA_URL} && pip install "./torch_npu-2.6.0.post1-cp311-cp311-manylinux_2_28_aarch64.whl" --no-cache-dir \
&& python3 -m pip install --no-cache-dir attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil==6.0.0 pytest==8.3.2 pytest-xdist==3.6.1 pyyaml pybind11 \
&& pip install ${TRITON_ASCEND_URL} --no-cache-dir
# Install SGLang
RUN git clone https://github.com/sgl-project/sglang --branch $SGLANG_TAG && \
(cd sglang/python && pip install -v ".[srt_npu]" --no-cache-dir) && \
(cd sglang/sgl-router && python -m build && pip install --force-reinstall dist/*.whl) && \
rm -rf sglang
# Install DeepEP
RUN git clone --branch $SGLANG_KERNEL_NPU_TAG https://github.com/sgl-project/sgl-kernel-npu.git \
&& export LD_LIBRARY_PATH=${ASCEND_CANN_PATH}/latest/runtime/lib64/stub:$LD_LIBRARY_PATH && \
. ${ASCEND_CANN_PATH}/set_env.sh && \
cd sgl-kernel-npu && \
bash build.sh \
&& pip install output/deep_ep*.whl --no-cache-dir \
&& cd .. && rm -rf sgl-kernel-npu \
&& cd "$(pip show deep-ep | awk '/^Location:/ {print $2}')" && ln -s deep_ep/deep_ep_cpp*.so
CMD ["/bin/bash"]
# Usage (to build SGLang ROCm docker image):
# docker build --build-arg SGL_BRANCH=v0.5.2 --build-arg GPU_ARCH=gfx942 -t v0.5.2-rocm630-mi30x -f Dockerfile.rocm .
# docker build --build-arg SGL_BRANCH=v0.5.2 --build-arg GPU_ARCH=gfx942-rocm700 -t v0.5.2-rocm700-mi30x -f Dockerfile.rocm .
# docker build --build-arg SGL_BRANCH=v0.5.2 --build-arg GPU_ARCH=gfx950 -t v0.5.2-rocm700-mi35x -f Dockerfile.rocm .
# Default base images
ARG BASE_IMAGE_942="rocm/sgl-dev:vllm20250114"
ARG BASE_IMAGE_942_ROCM700="rocm/sgl-dev:rocm7-vllm-20250904"
ARG BASE_IMAGE_950="rocm/sgl-dev:rocm7-vllm-20250904"
# Declared before the FROM lines so it can select the build stage below
ARG GPU_ARCH=gfx950
# ===============================
# Base image 942 with rocm630 and args
FROM $BASE_IMAGE_942 AS gfx942
ENV BUILD_VLLM="0"
ENV BUILD_TRITON="1"
ENV BUILD_LLVM="0"
ENV BUILD_AITER_ALL="1"
ENV BUILD_MOONCAKE="1"
ENV AITER_COMMIT="v0.1.4"
ENV NO_DEPS_FLAG=""
# ===============================
# Base image 942 with rocm700 and args
FROM $BASE_IMAGE_942_ROCM700 AS gfx942-rocm700
ENV BUILD_VLLM="0"
ENV BUILD_TRITON="0"
ENV BUILD_LLVM="0"
ENV BUILD_AITER_ALL="1"
ENV BUILD_MOONCAKE="1"
ENV AITER_COMMIT="v0.1.5"
ENV NO_DEPS_FLAG=""
# ===============================
# Base image 950 and args
FROM $BASE_IMAGE_950 AS gfx950
ENV BUILD_VLLM="0"
ENV BUILD_TRITON="0"
ENV BUILD_LLVM="0"
ENV BUILD_AITER_ALL="1"
ENV BUILD_MOONCAKE="1"
ENV AITER_COMMIT="v0.1.5"
ENV NO_DEPS_FLAG="--no-deps"
# ===============================
# Chosen arch and args
FROM ${GPU_ARCH}
# Redeclare GPU_ARCH: ARGs defined before FROM go out of scope inside a build stage
ARG GPU_ARCH=gfx950
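# ${GPU_ARCH%-*} strips everything after the last '-', e.g. gfx942-rocm700 -> gfx942;
# this pattern-removal expansion needs a BuildKit Dockerfile frontend that supports it.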
ENV GPU_ARCH_LIST=${GPU_ARCH%-*}
ARG SGL_REPO="https://github.com/sgl-project/sglang.git"
ARG SGL_DEFAULT="main"
ARG SGL_BRANCH=${SGL_DEFAULT}
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
ARG TRITON_COMMIT="improve_fa_decode_3.0.0"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG LLVM_REPO="https://github.com/jrbyrnes/llvm-project.git"
ARG LLVM_BRANCH="MainOpSelV2"
ARG LLVM_COMMIT="6520ace8227ffe2728148d5f3b9872a870b0a560"
ARG MOONCAKE_REPO="https://github.com/kvcache-ai/Mooncake.git"
ARG MOONCAKE_COMMIT="dcdf1c784b40aa6975a8ed89fe26321b028e40e8"
USER root
# Install some basic utilities
RUN python -m pip install --upgrade pip && pip install setuptools_scm
RUN apt-get purge -y sccache; python -m pip uninstall -y sccache; rm -f "$(which sccache)"
WORKDIR /sgl-workspace
# -----------------------
# llvm
RUN if [ "$BUILD_LLVM" = "1" ]; then \
export HIP_CLANG_PATH="/sgl-workspace/llvm-project/build/bin/" \
&& git clone --single-branch ${LLVM_REPO} -b ${LLVM_BRANCH} \
&& cd llvm-project \
&& git checkout ${LLVM_COMMIT} \
&& mkdir build \
&& cd build \
&& cmake -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm \
&& make -j$(nproc); \
fi
# -----------------------
# -----------------------
# AITER
RUN pip uninstall -y aiter
RUN git clone ${AITER_REPO} \
&& cd aiter \
&& git checkout ${AITER_COMMIT} \
&& git submodule update --init --recursive
RUN cd aiter \
&& if [ "$BUILD_AITER_ALL" = "1" ] && [ "$BUILD_LLVM" = "1" ]; then \
HIP_CLANG_PATH=/sgl-workspace/llvm-project/build/bin/ PREBUILD_KERNELS=1 GPU_ARCHS=$GPU_ARCH_LIST python setup.py develop; \
elif [ "$BUILD_AITER_ALL" = "1" ]; then \
PREBUILD_KERNELS=1 GPU_ARCHS=$GPU_ARCH_LIST python setup.py develop; \
else \
GPU_ARCHS=$GPU_ARCH_LIST python setup.py develop; \
fi
# -----------------------
# Triton
RUN if [ "$BUILD_TRITON" = "1" ]; then \
pip uninstall -y triton \
&& git clone ${TRITON_REPO} \
&& cd triton \
&& git checkout ${TRITON_COMMIT} \
&& cd python \
&& python setup.py install; \
fi
# -----------------------
# Build vLLM
ARG VLLM_REPO="https://github.com/ROCm/vllm.git"
ARG VLLM_BRANCH="9f6b92db47c3444b7a7d67451ba0c3a2d6af4c2c"
RUN if [ "$BUILD_VLLM" = "1" ]; then \
git clone ${VLLM_REPO} \
&& cd vllm \
&& git checkout ${VLLM_BRANCH} \
&& python -m pip install -r requirements/rocm.txt \
&& python setup.py clean --all \
&& python setup.py develop; \
fi
# -----------------------
# Build Mooncake
ENV PATH=$PATH:/usr/local/go/bin
RUN if [ "$BUILD_MOONCAKE" = "1" ]; then \
apt update && apt install -y zip unzip wget && \
apt install -y gcc make libtool autoconf librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils perftest ethtool libibverbs-dev rdma-core && \
apt install -y openssh-server openmpi-bin openmpi-common libopenmpi-dev && \
git clone ${MOONCAKE_REPO} && \
cd Mooncake && \
git checkout ${MOONCAKE_COMMIT} && \
git submodule update --init --recursive && \
bash dependencies.sh -y && \
rm -rf /usr/local/go && \
wget https://go.dev/dl/go1.22.2.linux-amd64.tar.gz && \
tar -C /usr/local -xzf go1.22.2.linux-amd64.tar.gz && \
rm go1.22.2.linux-amd64.tar.gz && \
mkdir -p build && \
cd build && \
cmake .. -DUSE_ETCD=ON && \
make -j "$(nproc)" && make install; \
fi
# -----------------------
# Build SGLang
ARG BUILD_TYPE=all
RUN pip install IPython orjson python-multipart torchao==0.9.0 pybind11
RUN pip uninstall -y sgl_kernel sglang
RUN git clone ${SGL_REPO} \
&& cd sglang \
&& if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \
echo "Using ${SGL_DEFAULT}, default branch."; \
git checkout ${SGL_DEFAULT}; \
else \
echo "Using ${SGL_BRANCH} branch."; \
git checkout ${SGL_BRANCH}; \
fi \
&& cd sgl-kernel \
&& rm -f pyproject.toml \
&& mv pyproject_rocm.toml pyproject.toml \
&& AMDGPU_TARGET=$GPU_ARCH_LIST python setup_rocm.py install \
&& cd .. \
&& if [ "$BUILD_TYPE" = "srt" ]; then \
python -m pip --no-cache-dir install -e "python[srt_hip]" ${NO_DEPS_FLAG}; \
else \
python -m pip --no-cache-dir install -e "python[all_hip]" ${NO_DEPS_FLAG}; \
fi
RUN python -m pip cache purge
# Copy config files to support MI300X in virtualized environments (MI300X_VF). Symlinks cannot be created during the image build, so real copies are made instead.
RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \
/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \
-type f -name '*MI300X*' | xargs -I {} sh -c 'vf_config=$(echo "$1" | sed "s/MI300X/MI300X_VF/"); cp "$1" "$vf_config"' -- {}
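# (each config file whose name contains MI300X is duplicated with MI300X_VF in its name)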
# Performance environment variables.
ENV HIP_FORCE_DEV_KERNARG=1
ENV HSA_NO_SCRATCH_RECLAIM=1
ENV SGLANG_SET_CPU_AFFINITY=1
ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
ENV NCCL_MIN_NCHANNELS=112
ENV SGLANG_USE_AITER=1
ENV SGLANG_MOE_PADDING=1
ENV VLLM_FP8_PADDING=1
ENV VLLM_FP8_ACT_PADDING=1
ENV VLLM_FP8_WEIGHT_PADDING=1
ENV VLLM_FP8_REDUCE_CONV=1
ENV TORCHINDUCTOR_MAX_AUTOTUNE=1
ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1
CMD ["/bin/bash"]
######################## BASE IMAGE ##########################
FROM ubuntu:24.04 AS base
ARG PYTHON_VERSION=3.12
# set the environment variables
ENV PATH="/root/.local/bin:${PATH}"
ENV DEBIAN_FRONTEND=noninteractive
# uv environment variables
ENV UV_HTTP_TIMEOUT=500
ENV VIRTUAL_ENV="/opt/venv"
ENV UV_PYTHON_INSTALL_DIR=/opt/uv/python
ENV UV_LINK_MODE="copy"
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# install dependencies
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt update -y \
&& apt install -y curl \
&& rm -rf /var/lib/apt/lists/* \
&& apt clean
# install uv
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
# install python
RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
######################### BUILD IMAGE #########################
FROM base AS build-image
ARG SGLANG_REPO_REF=main
# set the environment variables
ENV PATH="/root/.cargo/bin:${PATH}"
# install dependencies
RUN apt update -y \
&& apt install -y git build-essential libssl-dev pkg-config protobuf-compiler \
&& rm -rf /var/lib/apt/lists/* \
&& apt clean
# install rustup from rustup.rs
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
&& rustc --version && cargo --version && protoc --version
# pull the github repository
RUN cd /opt \
&& git clone --depth=1 https://github.com/sgl-project/sglang.git \
&& cd /opt/sglang \
&& git fetch --depth=1 origin ${SGLANG_REPO_REF} \
&& git checkout FETCH_HEAD
# working directory
WORKDIR /opt/sglang/sgl-router
# build the rust dependencies
RUN cargo build --release \
&& uv build \
&& rm -rf /root/.cache
######################### ROUTER IMAGE #########################
FROM base AS router-image
# Copy the built package from the build image
COPY --from=build-image /opt/sglang/sgl-router/dist/*.whl dist/
# Install the wheel built in the previous stage
RUN uv pip install --force-reinstall dist/*.whl
# Clean up unnecessary files to reduce the image size
RUN rm -rf /root/.cache \
&& apt purge -y --auto-remove curl
# Set the entrypoint to the main command
ENTRYPOINT ["python3", "-m", "sglang_router.launch_router"]
FROM lmsysorg/sglang:latest
COPY serve /usr/bin/serve
RUN chmod 777 /usr/bin/serve
ENTRYPOINT [ "/usr/bin/serve" ]
FROM ubuntu:24.04
SHELL ["/bin/bash", "-c"]
ARG VER_SGLANG=main
ARG VER_TORCH=2.7.1
ARG VER_TORCHVISION=0.22.1
ARG VER_TRITON=3.3.1
RUN apt-get update && \
apt-get full-upgrade -y && \
DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
ca-certificates \
git \
curl \
wget \
vim \
gcc \
g++ \
make
WORKDIR /sgl-workspace
RUN curl -fsSL -o miniforge.sh https://github.com/conda-forge/miniforge/releases/download/24.11.3-2/Miniforge3-24.11.3-2-Linux-x86_64.sh && \
bash miniforge.sh -b -p ./miniforge3 && \
rm -f miniforge.sh && \
. miniforge3/bin/activate && \
conda install -y libsqlite==3.48.0 gperftools tbb libnuma numactl
ENV PATH=/sgl-workspace/miniforge3/bin:/sgl-workspace/miniforge3/condabin:${PATH}
ENV PIP_ROOT_USER_ACTION=ignore
ENV CONDA_PREFIX=/sgl-workspace/miniforge3
RUN pip config set global.index-url https://download.pytorch.org/whl/cpu && \
pip config set global.extra-index-url https://pypi.org/simple
RUN git clone https://github.com/sgl-project/sglang.git && \
cd sglang && \
git checkout ${VER_SGLANG} && \
pip install -e "python[all_cpu]" && \
pip install torch==${VER_TORCH} torchvision==${VER_TORCHVISION} triton==${VER_TRITON} --force-reinstall && \
cd sgl-kernel && \
cp pyproject_cpu.toml pyproject.toml && \
pip install .
ENV SGLANG_USE_CPU_ENGINE=1
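# Preload the Intel OpenMP runtime and the tcmalloc/tbbmalloc allocators,
# a common performance tweak for CPU inference.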
ENV LD_PRELOAD=/sgl-workspace/miniforge3/lib/libiomp5.so:/sgl-workspace/miniforge3/lib/libtcmalloc.so:/sgl-workspace/miniforge3/lib/libtbbmalloc.so.2
WORKDIR /sgl-workspace/sglang
services:
sglang:
image: lmsysorg/sglang:latest
container_name: sglang
volumes:
- ${HOME}/.cache/huggingface:/root/.cache/huggingface
# If you use ModelScope, you also need to mount this directory
# - ${HOME}/.cache/modelscope:/root/.cache/modelscope
restart: always
network_mode: host # required by RDMA
privileged: true # required by RDMA
# Alternatively, publish only port 30000
# ports:
# - 30000:30000
environment:
HF_TOKEN: <secret>
# If you use ModelScope to download the model, set this environment variable
# SGLANG_USE_MODELSCOPE: "true"
entrypoint: python3 -m sglang.launch_server
command: --model-path meta-llama/Llama-3.1-8B-Instruct
--host 0.0.0.0
--port 30000
ulimits:
memlock: -1
stack: 67108864
ipc: host
healthcheck:
test: ["CMD-SHELL", "curl -f http://localhost:30000/health || exit 1"]
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ["0"]
capabilities: [gpu]
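# Usage sketch (assuming this file is saved as compose.yaml):
#   docker compose up -d
#   curl http://localhost:30000/health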