Commit 909abb58 authored by maxiao's avatar maxiao
Browse files

adapt to sglang v0.5.2rc1 on dcu

parents
apiVersion: 1
datasources:
- name: Prometheus
type: prometheus
access: proxy
url: http://localhost:9090
isDefault: true
editable: false
# prometheus.yaml
global:
scrape_interval: 5s
evaluation_interval: 30s
scrape_configs:
- job_name: sglang
static_configs:
- targets:
- '127.0.0.1:30000'
# gputrc2graph.py
This script processes NVIDIA Nsight Systems (`nsys`) GPU trace files
(`.nsys-rep`) with -t cuda tracing enabled, and generates kernel-level
summaries and visualizations of GPU and non-GPU time. It is useful for
profiling and analyzing nsys profile output.
## Usage
### Command-line Arguments
- `--in_file`
**(required)**
List of input files and their metadata. Each entry should be in the format:
`<nsys-rep>,<engine>,<model>,<elapsed_nonprofiled_sec>`
- `nsys-rep`: Path to the `.nsys-rep` file.
- `engine`: Engine name (e.g., `sglang`).
- `model`: Model name (e.g., `llama`, `gpt-oss`, `ds`).
- `elapsed_nonprofiled_sec`: Wall-clock runtime (in seconds) without
profiling. Specify `0` to use the elapsed time from the nsys-rep file
(this may inflate non-GPU time if actual runtime without profiling is
less). Multiple entries can be provided, separated by spaces.
- `--out_dir`
Output directory for the generated CSV and HTML files.
If not specified, results are saved in the current directory.
- `--title`
Title for the HTML chart/visualization.
- `--nsys_cmd`
Path to the `nsys` command.
Default: `nsys` (assumes it is in your PATH).
Use this if `nsys` is not in your system PATH.
## Notes
- Make sure you have pandas installed. Any version is fine.
- Make sure [nsys](https://developer.nvidia.com/nsight-systems/get-started) is
installed, and specify the path to the `nsys` command with `--nsys_cmd` if it
is not in your PATH. The nsys version must be >= the nsys profile version that
was used to collect the traces when profiling the server, so that nsys can
process the nsys-rep that was generated.
- For more details on available engines and models, see the help string in
the script or run:
```bash
python3 gputrc2graph.py --help
```
## Example 1: analyze a single profile
To analyze the GPU cycles of for example, a llama-3.1-8B model with sglang:
1. Run the following command to collect nsys profile, for sglang server config.
```bash
nsys profile -t cuda -o nsys_res -f true --trace-fork-before-exec=true \
--cuda-graph-trace=node --delay <DELAY> --duration <DURATION> \
python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B ...
```
where:
- DELAY: how many seconds to delay nsys from collecting profiles, needed so
that profiles aren't captured till sglang server has come up and load
generation starts.
- DURATION: how many seconds for nsys profile to run before generating the
profile. This should be > the duration of the run.
2. After the server starts, run the client load generation command. Once the
test completes, after DURATION amount of time, nsys profile will generate an
nsys_res.nsys-rep file and shut down the server.
3. Run step #1 again, this time starting up the server without collecting the
profile.
4. Run step #2 again, and record the total time to complete the test in
seconds. This value will be used by the script to calculate the
CPU(non-GPU) seconds for the analysis.
5. Say the run elapsed time from step #4 is 132 seconds. Run script to
analyze:
```bash
python3 gputrc2graph.py \
--in_file run1.nsys-rep,sglang,llama,132
```
The command will produce 2 files for analysis:
- result.html: this categorizes kernel names into different categories in a
stacked bar chart.
- result.csv: shows how the kernel names are mapped to the different
categories.
### HTML visualization with result.html
The html file shows the number of elapsed seconds due to different GPU
Substages or categories, which consist of attention kernels as the biggest
category, at 63 seconds, followed by "gemm" kernels. This lets the user
prioritize the kernels to focus on for performance optimizations.
There's also an appended data table underneath the bar chart for copying out to
other post-processing tools.
### Kernel to category mapping with result.csv
Suppose the user would like to focus on improving triton kernels. It's not the
biggest consumer of cycles at .01 sec but perhaps it hasn't been optimized.
The next step is to use the result.csv to dive into what the kernels are which
compose the triton kernel GPU cycles.
## Example 2: analyze multiple profiles
Suppose the user has multiple nsys trace files, captured for different models,
say llama and gpt-oss in this case, and wish to compare their GPU/non-GPU
time, something like the following command can be used.
```bash
python3 gputrc2graph.py \
--in_file run1.nsys-rep,sglang,llama,100 run2.nsys-rep,sglang,gpt-oss,102 \
--out_dir results
```
The analysis process is similar to example 1 but now there will be multiple
stack bar charts that can be compared. The categories for the different
kernels will remain the same, so that it's easy to compare the GPU cycles for
the same categories.
Once a category is shown to have more cycles for one configuration than
another, the next step would be to use the csv file to see what kernels are
mapped into that category, and which kernels are taking the largest amount of
time which would cause a difference for the overall category.
## Example 3: add new classification for a new model
To create a new engine DEF with model ABC, just add another json file in the same directory as
gputrc2graph.py with the same format as the other json files. The script will automatically pick up all the json files in the same directory as engine/model specifications.
Then, for this new model, suppose there are 4 kernels to be classified into
"gemm" and "attn", where the gemm kernels have names with "*H*" or "*I*" in
them, and attn kernels have names with "*J*" or "*K*" in them, just add another
.json file in the same directory as gputrc2graph.py with the same format as
the other json files, like the following:
```json
{
"DEF": {
"ABC": {
"H|I": "gemm",
"J|K": "attn",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
}
}
}
```
Each entry in the dictionary consists of:
- key: a regex used to classify the kernels
- value: the category to classify the kernels into.
The last 2 entries are common for all engine/models, consisting of CUDA memory
operations and a 'misc' for anything that's leftover and can't be classified.
When invoking gputrc2graph.py, specify a trace file with this new model/engine
like the following:
```bash
--in_file new.nsys-rep,DEF,ABC,<runtime>
```
If the engine_DEF.json file already exists, just add the model as a new node in
the existing engine file, after the other models.
"""
This generates gpu kernel analysis output from nsys rep. Will call nsys
stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate
csv and html output for analysis
"""
import argparse
import logging
import os
import regex as re
logger = logging.getLogger(__name__)
# helper data class for annotating kernels
def load_engine_model():
"""returns engine_model built from all json files in the current dir"""
import glob
import json
engine_model = {}
json_files = glob.glob(os.path.join(os.path.dirname(__file__) or ".", "*.json"))
for fname in json_files:
with open(fname, encoding="utf-8") as f:
engine_model.update(json.load(f))
return engine_model
class GPUTrace2Graph:
"""
Parses output of nsys report, generates csv and bar chart output
"""
def __init__(self):
import pandas as pd # avoid importing till needed
self.pd = pd
self.pd.options.mode.copy_on_write = True
# helper functions for generating trace->summary csvs
def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file):
logger.info("loading %s", in_file)
df = self.pd.read_csv(
in_file, usecols=["Start (ns)", "Duration (ns)", "Device", "Strm", "Name"]
)
df["End (ns)"] = df["Start (ns)"] + df["Duration (ns)"]
df = self.sum_non_overlapping_intervals(df)
# get ready to print table with elapsed times per kernel
df["Instances"] = 1
df_sum = df.groupby("Name", as_index=False).agg(
{"Elapsed Time (ns)": "sum", "Duration (ns)": "sum", "Instances": "size"}
)
# generate csv
df_sum["Total Time (sec)"] = df_sum["Duration (ns)"] / 1e9
df_sum["Elapsed Time (sec)"] = df_sum["Elapsed Time (ns)"] / 1e9
df_sum = df_sum.sort_values(by="Elapsed Time (sec)", ascending=False)
df_sum[["Elapsed Time (sec)", "Total Time (sec)", "Instances", "Name"]].to_csv(
out_file, index=False
)
def sum_non_overlapping_intervals(self, df):
"""
returns new sorted df with Elapsed Time (ns) column using
vectorized operations
"""
logger.info("sorting %s trace records by start time", str(df.shape))
# Sort by start time and reset index
df = df.sort_values(by="Start (ns)").reset_index(drop=True)
# Initialize elapsed time as duration
df["Elapsed Time (ns)"] = df["Duration (ns)"]
# Get numpy arrays for faster operations
starts = df["Start (ns)"].values
ends = df["End (ns)"].values
# Keep track of current interval end
current_end = ends[0]
display_units = max(1, int(len(df) / 100))
# Update current_end for overlapping intervals
for i in range(1, len(df)):
if i % display_units == 0:
print(f"processing trace: {int(i/len(df) * 100)} %", end="\r")
if starts[i] <= current_end:
if ends[i] > current_end:
# Partial overlap
df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = (
ends[i] - current_end
)
current_end = ends[i]
else:
# Complete overlap
df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = 0
else:
# No overlap
current_end = ends[i]
return df
# functions for generating html files
def make_html(self, df, output_dir, title):
"""make html graph from df"""
import plotly.express as px
if df.empty:
return
output_name = os.path.join(output_dir, "result")
if not title:
title = "Model_Engine"
x = "Model_Engine"
y = "Elapsed Time (sec)"
color = "Category"
""" generate kernel mapping table """
# Sort Model_Engine categories by last field after underscore
df["Model_Engine"] = self.pd.Categorical(
df["Model_Engine"],
sorted(df["Model_Engine"].unique(), key=lambda x: x.split("_")[-1]),
)
df[["Model_Engine", color, "Instances", "Name", y]].sort_values(
by=color
).to_csv(f"{output_name}.csv", index=False)
graph = px.histogram(
df.round(2),
x=x,
y=y,
title=(f"{y} for {title}"),
color=color,
text_auto=True,
)
# wrap x axis labels
graph.update_xaxes(automargin=True)
graph.write_html(f"{output_name}.html")
"""
Generate data table with columns per Model_Engine into result.html
"""
pivot_df = df.pivot_table(
values="Elapsed Time (sec)",
index="Category",
columns="Model_Engine",
aggfunc="sum",
observed=False,
).round(2)
# Add sum row at bottom
pivot_df.loc["total_elapsed_sec"] = pivot_df.sum()
pivot_df.fillna("").to_html("temp.html")
with (
open(f"{output_name}.html", "a", encoding="utf-8") as outfile,
open("temp.html", encoding="utf-8") as infile,
):
outfile.write(infile.read())
os.remove("temp.html")
print(
f"Finished generating: \n"
f" {output_name}.html for stack bar chart \n"
f" {output_name}.csv for Kernel-Category mapping"
)
def anno_gpu_kernname(self, df, mapping):
"""add "Category" column"""
def anno_gpu_kernname_helper(name):
for kern_name, val in mapping.items():
if re.search(kern_name, name):
return val
df["Category"] = df["Name"].apply(anno_gpu_kernname_helper)
def make_nongpu_row(self, df, nongpu_sec):
"""this will append non-gpu time entry at end of df"""
nongpu_row = self.pd.DataFrame([df.iloc[-1]])
nongpu_row["Category"] = nongpu_row["Name"] = "CPU(non-GPU)"
nongpu_row["Instances"] = 1
nongpu_row["Elapsed Time (sec)"] = nongpu_sec
return nongpu_row
def is_valid_file(self, base_file):
"""asserts if base_file is non-existent or is empty"""
assert (
os.path.isfile(base_file) and os.path.getsize(base_file) > 0
), f"{base_file} doesn't exist or is empty"
def should_gen_file(self, new_file, base_file):
"""figure out if new file should be generated from base_file"""
self.is_valid_file(base_file)
if (
os.path.exists(new_file)
and (os.path.getmtime(new_file) > os.path.getmtime(base_file))
and (os.path.getsize(base_file) > 0)
):
logger.info("reusing %s", new_file)
return False
else:
logger.info("generating %s", new_file)
return True
def gen_sum_file(self, file, nsys_cmd):
"""
generates sum file from nsys trace with times per kernel and
returns the name of the sum file
"""
import subprocess
file_dir = os.path.dirname(file)
file_name = os.path.basename(file)
if not file_dir:
file_dir = "."
# Walk through trace and get the total non-overlapped time
nsys_stats_file = os.path.join(file_dir, f"{file_name}_cuda_gpu_trace.csv")
sum_file = os.path.join(file_dir, f"{file_name}_cuda_gpu_kernel_tracesum.csv")
if self.should_gen_file(nsys_stats_file, file):
cmd = [
nsys_cmd,
"stats",
"-r",
"cuda_gpu_trace",
file,
"-o",
f"{file_dir}/{file_name}",
]
cmd_str = " ".join(cmd)
logger.info("+ %s", cmd_str)
# estimate time based on calibrated 240M/min
file_size_mb = os.path.getsize(file) / 1e6
logger.info(
"nsys stats for %.2f MB file expected to take %.2f min",
file_size_mb,
file_size_mb / 240,
)
try:
subprocess.run(cmd, check=True)
except (FileNotFoundError, subprocess.CalledProcessError) as e:
logger.error(
"'%s' failed: %s. Use --nsys_cmd to specify nsys path", cmd_str, e
)
exit(1)
logger.info("generating non-overalapped sum %s", sum_file)
self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file)
self.is_valid_file(sum_file)
logger.info("Finished generating %s", sum_file)
return sum_file
def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model):
"""generates graph and csv file from in_file into out_dir"""
# Initialize an empty DataFrame to store combined data
combined_df = self.pd.DataFrame()
for idx, (file, engine, model, total_sec) in enumerate(in_file):
file_dir = os.path.dirname(file)
file_name = os.path.basename(file)
if not file_dir:
file_dir = "."
sum_file = self.gen_sum_file(file, nsys_cmd)
# read kernel summary file
df = self.pd.read_csv(sum_file)
# annotate kernel to their categories
assert engine_model.get(engine), f"engine {engine} unknown"
assert engine_model[engine].get(model), f"model {model} unknown"
# remove nsys-rep from file_name for shorter x-label
file_name = file_name.replace(".nsys-rep", "")
df["Model_Engine"] = f"{model}_{engine}_{file_name}_{idx}"
self.anno_gpu_kernname(df, engine_model[engine][model])
# patch in non-gpu time
gpu_sec = round(df["Elapsed Time (sec)"].sum(), 1)
total_sec = round(float(total_sec), 1)
if total_sec < gpu_sec:
logger.warning(
"Elapsed sec %.2f < GPU sec %.2f resetting Elapsed sec ",
total_sec,
gpu_sec,
)
total_sec = gpu_sec
nongpu_row = self.make_nongpu_row(df, total_sec - gpu_sec)
df = self.pd.concat([df, nongpu_row], ignore_index=True)
combined_df = self.pd.concat([combined_df, df], ignore_index=True)
if out_dir is None:
out_dir = "."
else:
os.makedirs(out_dir, exist_ok=True)
# generate html file
self.make_html(combined_df, out_dir, title)
def parse_tuple(s):
return tuple(s.split(","))
def main():
logging.basicConfig(
format=("%(asctime)s - %(levelname)s - %(message)s"), level=logging.INFO
)
parser = argparse.ArgumentParser(
description=(
"Process nsys rep and generate kernel non-overlapped cycles. \n"
"Example:\n"
"gputrc2graph.py --in_file d1.nsys-rep,sglang,llama,100 \n"
"d2.nsys-rep,sglang,gpt-oss,102 "
'--out_dir results/ --title "Model=gpt-oss SGLANG chart"'
),
formatter_class=argparse.RawDescriptionHelpFormatter,
)
# load supported engine_model
engine_model_supported = load_engine_model()
# Get a string representation of supported engine/model combinations
engine_model_supported_str = ", ".join(
f"{engine}:[{', '.join(models.keys())}]"
for engine, models in engine_model_supported.items()
)
parser.add_argument(
"--in_file",
type=parse_tuple,
nargs="+",
help=(
"list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) "
"separated by space. Elapsed_nonprofiled_sec is runtime without "
"profiling used to calculate non-gpu time. Specify 0 to use "
"elapsed time from nsys-rep but that might inflate non-gpu time. "
f"Available engine:[model] are: {engine_model_supported_str} "
f"Example: --infile d1.nsys-rep,sglan,llama,100 "
"d2.nsys-rep,sglang,gpt-oss,102"
),
required=True,
)
parser.add_argument("--out_dir", help=("output dir for result.csv/html"))
parser.add_argument("--title", help=("title for html chart"))
parser.add_argument(
"--nsys_cmd",
help=("nsys cmd, e.g. /usr/bin/nsys, Default: nsys"),
default="nsys",
)
args = parser.parse_args()
gputrace = GPUTrace2Graph()
gputrace.gen_graph(
args.in_file, args.out_dir, args.title, args.nsys_cmd, engine_model_supported
)
if __name__ == "__main__":
main()
{
"sglang": {
"llama": {
"gemm|nvjet": "gemm",
"fused_moe_kernel|GroupProblemShape|group_gemm_starts|bmm_|GemmUniversal": "moe_gemm",
"moe|sigmoid": "moe",
"CatArrayBatched|prepare_inputs": "prepare_next",
"ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
"_norm_|Norm": "norm",
"topk": "topk",
"act_and_mul_": "activation",
"Rotary": "rope",
"SoftMax": "softmax",
"flash|fmha": "attn",
"elementwise": "elementwise",
"fp8_quant|cvt_|quantize": "quantize",
"reduce_kernel": "reduce",
"triton": "triton_kernel",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
},
"ds": {
"block_fp8_matmul": "block_fp8_gemm",
"gemm|matmul|nvjet": "gemm",
"fused_moe_kernel": "moe_gemm",
"moe|expert|sigmoid": "moe",
"CatArrayBatched|write_req_to": "prepare_next",
"ncclDevKernel|cross_device_reduce|all_gather": "nccl_and_custom_ar",
"Norm": "norm",
"topk": "topk",
"activation|act_and_mul": "activation",
"compute_position_kernel": "rope",
"elementwise": "elementwise",
"fp8_quant|quant_fp8|quantize": "quantize",
"SoftMax": "softmax",
"reduce": "reduce",
"_fwd_|create_flash|::mla::|KVCache": "attn",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
},
"gpt-oss": {
"gemm|nvjet": "gemm",
"fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_|matmul_ogs_|_topk_forward|_combined_routing|_sum_bitmatrix_rows|_compute_writeback_idx": "moe_gemm",
"moe|sigmoid": "moe",
"CatArrayBatched|prepare_inputs": "prepare_next",
"_norm_|Norm": "norm",
"ncclDevKernel|cross_device_reduce|allreduce": "nccl_and_custom_ar",
"topk|TopK": "topk",
"act_and_mul_": "activation",
"Rotary": "rope",
"SoftMax": "softmax",
"flash|fmha": "attn",
"elementwise": "elementwise",
"fp8_quant|cvt_|quantize": "quantize",
"reduce_kernel": "reduce",
"triton": "triton_kernel",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
}
}
}
# Runtime examples
The below examples will mostly need you to start a server in a separate terminal before you can execute them. Please see in the code for detailed instruction.
## Native API
* `lora.py`: An example how to use LoRA adapters.
* `multimodal_embedding.py`: An example how perform [multi modal embedding](Alibaba-NLP/gme-Qwen2-VL-2B-Instruct).
* `openai_batch_chat.py`: An example how to process batch requests for chat completions.
* `openai_batch_complete.py`: An example how to process batch requests for text completions.
* **`openai_chat_with_response_prefill.py`**:
An example that demonstrates how to [prefill a response](https://eugeneyan.com/writing/prompting/#prefill-claudes-responses) using the OpenAI API by enabling the `continue_final_message` parameter.
When enabled, the final (partial) assistant message is removed and its content is used as a prefill so that the model continues that message rather than starting a new turn. See [Anthropic's prefill example](https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/prefill-claudes-response#example-structured-data-extraction-with-prefilling) for more context.
* `reward_model.py`: An example how to extract scores from a reward model.
* `vertex_predict.py`: An example how to deploy a model to [Vertex AI](https://cloud.google.com/vertex-ai?hl=en).
## Engine
The `engine` folder contains that examples that show how to use [Offline Engine API](https://docs.sglang.ai/backend/offline_engine_api.html#Offline-Engine-API) for common workflows.
* `custom_server.py`: An example how to deploy a custom server.
* `embedding.py`: An example how to extract embeddings.
* `launch_engine.py`: An example how to launch the Engine.
* `offline_batch_inference_eagle.py`: An example how to perform speculative decoding using [EAGLE](https://docs.sglang.ai/backend/speculative_decoding.html).
* `offline_batch_inference_torchrun.py`: An example how to perform inference using [torchrun](https://pytorch.org/docs/stable/elastic/run.html).
* `offline_batch_inference_vlm.py`: An example how to use VLMs with the engine.
* `offline_batch_inference.py`: An example how to use the engine to perform inference on a batch of examples.
## Hidden States
The `hidden_states` folder contains examples on how to extract hidden states using SGLang. Please note that this might degrade throughput due to cuda graph rebuilding.
* `hidden_states_engine.py`: An example how to extract hidden states using the Engine API.
* `hidden_states_server.py`: An example how to extract hidden states using the Server API.
## Multimodal
SGLang supports multimodal inputs for various model architectures. The `multimodal` folder contains examples showing how to use urls, files or encoded data to make requests to multimodal models. Examples include querying the [Llava-OneVision](multimodal/llava_onevision_server.py) model (image, multi-image, video), Llava-backed [Qwen-Llava](multimodal/qwen_llava_server.py) and [Llama3-Llava](multimodal/llama3_llava_server.py) models (image, multi-image), and Mistral AI's [Pixtral](multimodal/pixtral_server.py) (image, multi-image).
## Token In, Token Out
The folder `token_in_token_out` shows how to perform inference, where we provide tokens and get tokens as response.
* `token_in_token_out_{llm|vlm}_{engine|server}.py`: Shows how to perform token in, token out workflow for llm/vlm using either the engine or native API.
from sanic import Sanic, text
from sanic.response import json
import sglang as sgl
engine = None
# Create an instance of the Sanic app
app = Sanic("sanic-server")
# Define an asynchronous route handler
@app.route("/generate", methods=["POST"])
async def generate(request):
prompt = request.json.get("prompt")
if not prompt:
return json({"error": "Prompt is required"}, status=400)
# async_generate returns a dict
result = await engine.async_generate(prompt)
return text(result["text"])
@app.route("/generate_stream", methods=["POST"])
async def generate_stream(request):
prompt = request.json.get("prompt")
if not prompt:
return json({"error": "Prompt is required"}, status=400)
# async_generate returns a dict
result = await engine.async_generate(prompt, stream=True)
# https://sanic.dev/en/guide/advanced/streaming.md#streaming
# init the response
response = await request.respond()
# result is an async generator
async for chunk in result:
await response.send(chunk["text"])
await response.eof()
def run_server():
global engine
engine = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
app.run(host="0.0.0.0", port=8000, single_process=True)
if __name__ == "__main__":
run_server()
import sglang as sgl
def main():
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
llm = sgl.Engine(
model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct", is_embedding=True
)
outputs = llm.encode(prompts)
# Print the outputs (embedding vectors)
for prompt, output in zip(prompts, outputs):
print("===============================")
print(f"Prompt: {prompt}\nEmbedding vector: {output['embedding']}")
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
main()
"""
FastAPI server example for text generation using SGLang Engine and demonstrating client usage.
Starts the server, sends requests to it, and prints responses.
Usage:
python fastapi_engine_inference.py --model-path Qwen/Qwen2.5-0.5B-Instruct --tp_size 1 --host 127.0.0.1 --port 8000
"""
import os
import subprocess
import time
from contextlib import asynccontextmanager
import requests
from fastapi import FastAPI, Request
import sglang as sgl
from sglang.utils import terminate_process
engine = None
# Use FastAPI's lifespan manager to initialize/shutdown the engine
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manages SGLang engine initialization during server startup."""
global engine
# Initialize the SGLang engine when the server starts
# Adjust model_path and other engine arguments as needed
print("Loading SGLang engine...")
engine = sgl.Engine(
model_path=os.getenv("MODEL_PATH"), tp_size=int(os.getenv("TP_SIZE"))
)
print("SGLang engine loaded.")
yield
# Clean up engine resources when the server stops (optional, depends on engine needs)
print("Shutting down SGLang engine...")
# engine.shutdown() # Or other cleanup if available/necessary
print("SGLang engine shutdown.")
app = FastAPI(lifespan=lifespan)
@app.post("/generate")
async def generate_text(request: Request):
"""FastAPI endpoint to handle text generation requests."""
global engine
if not engine:
return {"error": "Engine not initialized"}, 503
try:
data = await request.json()
prompt = data.get("prompt")
max_new_tokens = data.get("max_new_tokens", 128)
temperature = data.get("temperature", 0.7)
if not prompt:
return {"error": "Prompt is required"}, 400
# Use async_generate for non-blocking generation
state = await engine.async_generate(
prompt,
sampling_params={
"max_new_tokens": max_new_tokens,
"temperature": temperature,
},
# Add other parameters like stop, top_p etc. as needed
)
return {"generated_text": state["text"]}
except Exception as e:
return {"error": str(e)}, 500
# Helper function to start the server
def start_server(args, timeout=60):
"""Starts the Uvicorn server as a subprocess and waits for it to be ready."""
base_url = f"http://{args.host}:{args.port}"
command = [
"python",
"-m",
"uvicorn",
"fastapi_engine_inference:app",
f"--host={args.host}",
f"--port={args.port}",
]
process = subprocess.Popen(command, stdout=None, stderr=None)
start_time = time.perf_counter()
with requests.Session() as session:
while time.perf_counter() - start_time < timeout:
try:
# Check the /docs endpoint which FastAPI provides by default
response = session.get(
f"{base_url}/docs", timeout=5
) # Add a request timeout
if response.status_code == 200:
print(f"Server {base_url} is ready (responded on /docs)")
return process
except requests.ConnectionError:
# Specific exception for connection refused/DNS error etc.
pass
except requests.Timeout:
# Specific exception for request timeout
print(f"Health check to {base_url}/docs timed out, retrying...")
pass
except requests.RequestException as e:
# Catch other request exceptions
print(f"Health check request error: {e}, retrying...")
pass
# Use a shorter sleep interval for faster startup detection
time.sleep(1)
# If loop finishes, raise the timeout error
# Attempt to terminate the failed process before raising
if process:
print(
"Server failed to start within timeout, attempting to terminate process..."
)
terminate_process(process) # Use the imported terminate_process
raise TimeoutError(
f"Server failed to start at {base_url} within the timeout period."
)
def send_requests(server_url, prompts, max_new_tokens, temperature):
"""Sends generation requests to the running server for a list of prompts."""
# Iterate through prompts and send requests
for i, prompt in enumerate(prompts):
print(f"\n[{i+1}/{len(prompts)}] Sending prompt: '{prompt}'")
payload = {
"prompt": prompt,
"max_new_tokens": max_new_tokens,
"temperature": temperature,
}
try:
response = requests.post(f"{server_url}/generate", json=payload, timeout=60)
result = response.json()
print(f"Prompt: {prompt}\nResponse: {result['generated_text']}")
except requests.exceptions.Timeout:
print(f" Error: Request timed out for prompt '{prompt}'")
except requests.exceptions.RequestException as e:
print(f" Error sending request for prompt '{prompt}': {e}")
if __name__ == "__main__":
"""Main entry point for the script."""
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="127.0.0.1")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model-path", type=str, default="Qwen/Qwen2.5-0.5B-Instruct")
parser.add_argument("--tp_size", type=int, default=1)
args = parser.parse_args()
# Pass the model to the child uvicorn process via an env var
os.environ["MODEL_PATH"] = args.model_path
os.environ["TP_SIZE"] = str(args.tp_size)
# Start the server
process = start_server(args)
# Define the prompts and sampling parameters
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
max_new_tokens = 64
temperature = 0.1
# Define server url
server_url = f"http://{args.host}:{args.port}"
# Send requests to the server
send_requests(server_url, prompts, max_new_tokens, temperature)
# Terminate the server process
terminate_process(process)
"""
This example demonstrates how to launch the offline engine.
"""
import sglang as sgl
def main():
llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
llm.generate("What is the capital of France?")
llm.shutdown()
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
main()
"""
Usage:
python3 offline_batch_inference.py --model meta-llama/Llama-3.1-8B-Instruct
"""
import argparse
import dataclasses
import sglang as sgl
from sglang.srt.server_args import ServerArgs
def main(
server_args: ServerArgs,
):
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = {"temperature": 0.8, "top_p": 0.95}
# Create an LLM.
llm = sgl.Engine(**dataclasses.asdict(server_args))
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for prompt, output in zip(prompts, outputs):
print("===============================")
print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
main(server_args)
"""
Usage:
python offline_batch_inference_async.py --model-path Qwen/Qwen2-VL-7B-Instruct
Note:
This demo shows the usage of async generation,
which is useful to implement an online-like generation with batched inference.
"""
import argparse
import asyncio
import dataclasses
import time
import sglang as sgl
from sglang.srt.server_args import ServerArgs
class InferenceEngine:
def __init__(self, **kwargs):
self.engine = sgl.Engine(**kwargs)
async def generate(self, prompt, sampling_params):
result = await self.engine.async_generate(prompt, sampling_params)
return result
async def run_server(server_args):
inference = InferenceEngine(**dataclasses.asdict(server_args))
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
] * 100
# Create a sampling params object.
sampling_params = {"temperature": 0.8, "top_p": 0.95}
# Run the generation tasks concurrently in async mode.
tasks = []
for prompt in prompts:
task = asyncio.create_task(inference.generate(prompt, sampling_params))
tasks.append(task)
# Get and print the result
for task in tasks:
await task
while True:
if not task.done():
time.sleep(1)
else:
result = task.result()
print(f"Generated text: {result['text']}")
break
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
asyncio.run(run_server(server_args))
import sglang as sgl
def main():
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = {"temperature": 0, "max_new_tokens": 30}
# Create an LLM.
llm = sgl.Engine(
model_path="meta-llama/Llama-2-7b-chat-hf",
speculative_algorithm="EAGLE",
speculative_draft_model_path="lmsys/sglang-EAGLE-llama2-chat-7B",
speculative_num_steps=3,
speculative_eagle_topk=4,
speculative_num_draft_tokens=16,
cuda_graph_max_bs=8,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for prompt, output in zip(prompts, outputs):
print("===============================")
print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
main()
"""
Usage:
python3 offline_batch_inference.py
"""
from urllib.request import urlopen
import sglang as sgl
def load_prompt() -> str:
# Test cases with various lengths can be found at:
#
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
with urlopen(
"https://qianwen-res.oss-cn-beijing.aliyuncs.com"
"/Qwen2.5-1M/test-data/64k.txt",
timeout=5,
) as response:
prompt = response.read().decode("utf-8")
return prompt
# Processing the prompt.
def process_requests(llm: sgl.Engine, prompts: list[str]) -> None:
# Create a sampling params object.
sampling_params = {
"temperature": 0.7,
"top_p": 0.8,
"top_k": 20,
"repetition_penalty": 1.05,
"max_new_tokens": 256,
}
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt_token_ids = output["meta_info"]["prompt_tokens"]
generated_text = output["text"]
print(
f"Prompt length: {prompt_token_ids}, " f"Generated text: {generated_text!r}"
)
# Create an LLM.
def initialize_engine() -> sgl.Engine:
llm = sgl.Engine(
model_path="Qwen/Qwen2.5-7B-Instruct-1M",
context_length=1048576,
page_size=256,
attention_backend="dual_chunk_flash_attn",
tp_size=4,
disable_radix_cache=True,
enable_mixed_chunk=False,
enable_torch_compile=False,
chunked_prefill_size=131072,
mem_fraction_static=0.6,
log_level="DEBUG",
)
return llm
def main():
llm = initialize_engine()
prompt = load_prompt()
process_requests(llm, [prompt])
if __name__ == "__main__":
main()
"""
Usage:
python offline_batch_inference_vlm.py --model-path Qwen/Qwen2-VL-7B-Instruct
"""
import argparse
import dataclasses
import sglang as sgl
from sglang.srt.parser.conversation import chat_templates
from sglang.srt.server_args import ServerArgs
def main(
server_args: ServerArgs,
):
vlm = sgl.Engine(**dataclasses.asdict(server_args))
conv = chat_templates[server_args.chat_template].copy()
image_token = conv.image_token
image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
prompt = f"What's in this image?\n{image_token}"
sampling_params = {
"temperature": 0.001,
"max_new_tokens": 30,
}
output = vlm.generate(
prompt=prompt,
image_data=image_url,
sampling_params=sampling_params,
)
print("===============================")
print(f"Prompt: {prompt}")
print(f"Generated text: {output['text']}")
vlm.shutdown()
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
main(server_args)
# SGLang Engine
SGLang provides a direct inference engine without the need for an HTTP server. There are generally these use cases:
- [Offline Batch Inference](#offline-batch-inference)
- [Embedding Generation](#embedding-generation)
- [Custom Server](#custom-server)
- [Token-In-Token-Out for RLHF](#token-in-token-out-for-rlhf)
- [Inference Using FastAPI](#inference-using-fastapi)
## Examples
### [Offline Batch Inference](./offline_batch_inference.py)
In this example, we launch an SGLang engine and feed a batch of inputs for inference. If you provide a very large batch, the engine will intelligently schedule the requests to process efficiently and prevent OOM (Out of Memory) errors.
### [Embedding Generation](./embedding.py)
In this example, we launch an SGLang engine and feed a batch of inputs for embedding generation.
### [Custom Server](./custom_server.py)
This example demonstrates how to create a custom server on top of the SGLang Engine. We use [Sanic](https://sanic.dev/en/) as an example. The server supports both non-streaming and streaming endpoints.
#### Steps
1. Install Sanic:
```bash
pip install sanic
```
2. Run the server:
```bash
python custom_server
```
3. Send requests:
```bash
curl -X POST http://localhost:8000/generate -H "Content-Type: application/json" -d '{"prompt": "The Transformer architecture is..."}'
curl -X POST http://localhost:8000/generate_stream -H "Content-Type: application/json" -d '{"prompt": "The Transformer architecture is..."}' --no-buffer
```
This will send both non-streaming and streaming requests to the server.
### [Token-In-Token-Out for RLHF](../token_in_token_out)
In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.
### [Inference Using FastAPI](fastapi_engine_inference.py)
This example demonstrates how to create a FastAPI server that uses the SGLang engine for text generation.
# SPDX-License-Identifier: Apache-2.0
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.
Example usage:
python save_remote_state.py \
--model-path /path/to/load \
--tensor-parallel-size 8 \
--remote-model-save-url [protocol]://[host]:[port]/[model_name] \
Then, the model can be loaded with
llm = Engine(
model_path="[protocol]://[host]:[port]/[model_name]",
tensor_parallel_size=8,
)
"""
import dataclasses
from argparse import ArgumentParser
from pathlib import Path
from sglang import Engine, ServerArgs
parser = ArgumentParser()
ServerArgs.add_cli_args(parser)
parser.add_argument(
"--remote-model-save-url",
required=True,
type=str,
help="remote address to store model weights",
)
parser.add_argument(
"--remote-draft-model-save-url",
default=None,
type=str,
help="remote address to store draft model weights",
)
def main(args):
engine_args = ServerArgs.from_cli_args(args)
model_path = engine_args.model_path
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
llm = Engine(**dataclasses.asdict(engine_args))
llm.save_remote_model(
url=args.remote_model_save_url, draft_url=args.remote_draft_model_save_url
)
print("save remote (draft) model successfully")
if __name__ == "__main__":
args = parser.parse_args()
main(args)
# SPDX-License-Identifier: Apache-2.0
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.
Example usage:
python save_sharded_state.py \
--model-path /path/to/load \
--quantization deepspeedfp \
--tensor-parallel-size 8 \
--output /path/to/save
Then, the model can be loaded with
llm = Engine(
model_path="/path/to/save",
load_format="sharded_state",
quantization="deepspeedfp",
tensor_parallel_size=8,
)
"""
import dataclasses
import os
import shutil
from argparse import ArgumentParser
from pathlib import Path
from sglang import Engine, ServerArgs
parser = ArgumentParser()
ServerArgs.add_cli_args(parser)
parser.add_argument(
"--output", "-o", required=True, type=str, help="path to output checkpoint"
)
parser.add_argument(
"--file-pattern", type=str, help="string pattern of saved filenames"
)
parser.add_argument(
"--max-file-size",
type=str,
default=5 * 1024**3,
help="max size (in bytes) of each safetensors file",
)
def main(args):
engine_args = ServerArgs.from_cli_args(args)
model_path = engine_args.model_path
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
llm = Engine(**dataclasses.asdict(engine_args))
Path(args.output).mkdir(exist_ok=True)
llm.save_sharded_model(
path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
)
# Copy metadata files to output directory
for file in os.listdir(model_path):
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
if os.path.isdir(os.path.join(model_path, file)):
shutil.copytree(
os.path.join(model_path, file), os.path.join(args.output, file)
)
else:
shutil.copy(os.path.join(model_path, file), args.output)
if __name__ == "__main__":
args = parser.parse_args()
main(args)
"""
Usage:
python hidden_states.py
Note that each time you change the `return_hidden_states` parameter,
the cuda graph will be recaptured, which might lead to a performance hit.
So avoid getting hidden states and completions alternately.
"""
import torch
import sglang as sgl
def main():
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
llm = sgl.Engine(
model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
enable_return_hidden_states=True,
)
sampling_params = {
"temperature": 0.8,
"top_p": 0.95,
"max_new_tokens": 10,
}
outputs = llm.generate(
prompts, sampling_params=sampling_params, return_hidden_states=True
)
llm.shutdown()
for prompt, output in zip(prompts, outputs):
for i in range(len(output["meta_info"]["hidden_states"])):
output["meta_info"]["hidden_states"][i] = torch.tensor(
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
)
print("===============================")
print(
f"Prompt: {prompt}\n"
f"Generated text: {output['text']}\n"
f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
f"Completion_tokens: {output['meta_info']['completion_tokens']}"
)
print("Hidden states: ")
hidden_states = torch.cat(
[
i.unsqueeze(0) if len(i.shape) == 1 else i
for i in output["meta_info"]["hidden_states"]
]
)
print(hidden_states)
print()
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
main()
"""
Usage:
python hidden_states_server.py
Note that each time you change the `return_hidden_states` parameter,
the cuda graph will be recaptured, which might lead to a performance hit.
So avoid getting hidden states and completions alternately.
"""
import requests
import torch
from sglang.test.test_utils import is_in_ci
from sglang.utils import terminate_process, wait_for_server
if is_in_ci():
from docs.backend.patch import launch_server_cmd
else:
from sglang.utils import launch_server_cmd
def main():
# Launch the server
server_process, port = launch_server_cmd(
"python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct --enable-return-hidden-states --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = {
"temperature": 0.8,
"top_p": 0.95,
"max_new_tokens": 10,
}
json_data = {
"text": prompts,
"sampling_params": sampling_params,
"return_hidden_states": True,
}
response = requests.post(
f"http://localhost:{port}/generate",
json=json_data,
)
terminate_process(server_process)
outputs = response.json()
for prompt, output in zip(prompts, outputs):
for i in range(len(output["meta_info"]["hidden_states"])):
output["meta_info"]["hidden_states"][i] = torch.tensor(
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
)
print("===============================")
print(
f"Prompt: {prompt}\n"
f"Generated text: {output['text']}\n"
f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
f"Completion_tokens: {output['meta_info']['completion_tokens']}"
)
print("Hidden states: ")
hidden_states = torch.cat(
[
i.unsqueeze(0) if len(i.shape) == 1 else i
for i in output["meta_info"]["hidden_states"]
]
)
print(hidden_states)
print()
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment