Unverified commit 2d72fc47, authored by Lianmin Zheng and committed by GitHub

Improve profiler and integrate profiler in bench_one_batch_server (#6787)

parent b520d028
@@ -23,8 +23,8 @@ If you frequently see `token usage < 0.9` and `#queue-req > 0`, it means the ser
 The case of server being too conservative can happen when users send many requests with a large `max_new_tokens` but the requests stop very early due to EOS or stop strings.
 On the other hand, if you see `token usage` very high and you frequently see warnings like
-`decode out of memory happened, #retracted_reqs: 1, #new_token_ratio: 0.9998 -> 1.0000`, you can increase `--schedule-conservativeness` to a value like 1.3.
-If you see `decode out of memory happened` occasionally but not frequently, it is okay.
+`KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_token_ratio: 0.9998 -> 1.0000`, you can increase `--schedule-conservativeness` to a value like 1.3.
+If you see `KV cache pool is full. Retract requests.` occasionally but not frequently, it is okay.
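To make the flag concrete, here is a minimal sketch of relaunching the server with a higher conservativeness; the model path and the value 1.3 are illustrative, not part of the original docs:

```python
import subprocess

# Relaunch the server with a more conservative scheduler when the
# "KV cache pool is full. Retract requests." warning shows up frequently.
# The model path below is only an example.
subprocess.run([
    "python3", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "--schedule-conservativeness", "1.3",
])
```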
### Tune `--dp-size` and `--tp-size`
......
@@ -8,6 +8,7 @@ Usage:
 python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
 python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
+python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
 """

 import argparse
@@ -19,10 +20,10 @@ import os
 import time
 from typing import Tuple

-import numpy as np
 import requests

 from sglang.bench_serving import get_tokenizer, sample_random_requests
+from sglang.profiler import run_profile
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
@@ -42,6 +43,8 @@ class BenchArgs:
     base_url: str = ""
     skip_warmup: bool = False
     show_report: bool = False
+    profile: bool = False
+    profile_by_stage: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -68,6 +71,8 @@ class BenchArgs:
         parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
         parser.add_argument("--skip-warmup", action="store_true")
         parser.add_argument("--show-report", action="store_true")
+        parser.add_argument("--profile", action="store_true")
+        parser.add_argument("--profile-by-stage", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -93,8 +98,8 @@ def launch_server_process(server_args: ServerArgs):
     base_url = f"http://{server_args.host}:{server_args.port}"
     timeout = 600

-    start_time = time.perf_counter()
-    while time.perf_counter() - start_time < timeout:
+    start_time = time.time()
+    while time.time() - start_time < timeout:
         try:
             headers = {
                 "Content-Type": "application/json; charset=utf-8",
@@ -119,6 +124,8 @@ def run_one_case(
     run_name: str,
     result_filename: str,
     tokenizer,
+    profile: bool = False,
+    profile_by_stage: bool = False,
 ):
     requests.post(url + "/flush_cache")
     input_requests = sample_random_requests(
@@ -145,6 +152,12 @@ def run_one_case(
     else:
         json_schema = None

+    profile_link = None
+    if profile:
+        profile_link: str = run_profile(
+            url, 3, ["CPU", "GPU"], None, None, profile_by_stage
+        )
+
     tic = time.perf_counter()
     response = requests.post(
         url + "/generate",
@@ -194,8 +207,8 @@ def run_one_case(
     print(f"output_len: {output_len}")
     print(f"latency: {latency:.2f} s")
     print(f"ttft: {ttft:.2f} s")
-    print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
-    print(f"Input throughput: {input_throughput:.2f} tok/s")
+    print(f"last generation throughput: {last_gen_throughput:.2f} tok/s")
+    print(f"input throughput: {input_throughput:.2f} tok/s")
     if output_len != 1:
         print(f"output throughput: {output_throughput:.2f} tok/s")
@@ -222,6 +235,7 @@ def run_one_case(
         overall_throughput,
         last_gen_throughput,
         acc_length,
+        profile_link if profile else None,
     )
@@ -253,6 +267,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):

     # benchmark
     result = []
+    bench_result = []
     try:
         for bs, il, ol in itertools.product(
             bench_args.batch_size, bench_args.input_len, bench_args.output_len
@@ -271,6 +286,33 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                     tokenizer=tokenizer,
                 )
             )
+
+        if bench_args.profile:
+            try:
+                for bs, il, ol in itertools.product(
+                    bench_args.batch_size, bench_args.input_len, bench_args.output_len
+                ):
+                    bench_result.append(
+                        (
+                            run_one_case(
+                                base_url,
+                                bs,
+                                il,
+                                ol,
+                                temperature=bench_args.temperature,
+                                return_logprob=bench_args.return_logprob,
+                                input_len_step_percentage=bench_args.input_len_step_percentage,
+                                run_name=bench_args.run_name,
+                                result_filename=bench_args.result_filename,
+                                tokenizer=tokenizer,
+                                profile=bench_args.profile,
+                                profile_by_stage=bench_args.profile_by_stage,
+                            )[-1],
+                        )
+                    )
+                result = [t1[:-1] + t2 for t1, t2 in zip(result, bench_result)]
+            except Exception as e:
+                print(f"Error profiling, there will be no profile trace dump: {e}")
     finally:
         if proc:
             kill_process_tree(proc.pid)
@@ -280,8 +322,20 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     if not bench_args.show_report:
         return

-    summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
-    summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
+    summary = (
+        f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
+    )
+    summary += "| batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) |"
+    if bench_args.profile:
+        summary += " profile |"
+    summary += "\n"
+    summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ |"
+    if bench_args.profile:
+        summary += "-------------|"
+    summary += "\n"

     for (
         batch_size,
@@ -292,6 +346,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
         overall_throughput,
         last_gen_throughput,
         acc_length,
+        trace_link,
     ) in result:
         hourly_cost = 2 * server_args.tp_size  # $2/hour for one H100
         input_util = 0.7
@@ -304,17 +359,18 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             f"{accept_length} | "
             f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
             f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
-            f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
+            f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |"
         )
+        if trace_link:
+            line += f" [Profile]({trace_link}) |"
+        line += "\n"
         summary += line

     # print metrics table
     print(summary)

     if is_in_ci():
-        write_github_step_summary(
-            f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
-        )
+        write_github_step_summary(summary)

 if __name__ == "__main__":
......
"""
Run live profiling.
Usage:
python3 -m sglang.profiler
"""
import argparse
import json
import os
import time
import urllib.parse
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Optional
import requests
PARENT_FOLDER = "/tmp/sglang-profile"
def _run_profile(
url: Optional[str],
num_steps: int,
activities: List[str],
output_dir: Optional[str] = None,
profile_name: Optional[str] = None,
profile_by_stage: bool = False,
) -> str:
if output_dir is None:
output_dir = PARENT_FOLDER
output_dir = os.path.normpath(output_dir)
output_dir = os.path.abspath(output_dir)
output_dir = Path(output_dir)
# Add "profile_name/timestamp" to the path.
if profile_name:
output_dir = output_dir / profile_name
output_dir = output_dir / str(time.time())
output_dir.mkdir(exist_ok=True, parents=True)
print(f"Dump profiling traces to {output_dir}")
print(
f"Waiting for {num_steps} steps and the trace to be flushed.... ({profile_by_stage=})"
)
# Dump server args.
file_path = Path(output_dir) / "server_args.json"
if not file_path.exists():
response = requests.get(url + "/get_server_info")
response.raise_for_status()
server_args_data = response.json()
with open(file_path, "w") as file:
file.write(json.dumps(server_args_data))
# Start profiler. The API replies when all steps are processed
# and files are generated.
json_data = {
"output_dir": str(output_dir),
"num_steps": str(num_steps),
"activities": activities,
"profile_by_stage": profile_by_stage,
}
response = requests.post(url=url + "/start_profile", json=json_data)
response.raise_for_status()
trace_link = str(output_dir)
return trace_link
def run_profile(
url: Optional[str],
num_steps: int,
activities: List[str],
output_dir: Optional[str] = None,
profile_name: Optional[str] = None,
profile_by_stage: bool = False,
):
# A step-based profile terminates itself after num_steps forward steps.
link = _run_profile(
url, num_steps, activities, output_dir, profile_name, profile_by_stage
)
return link
if __name__ == "__main__":
parser = ArgumentParser(description="Run live profiling on a running SGLang server.")
parser.add_argument(
"--url",
type=str,
default="http://localhost:30000",
help="Server or API base url if not using http host and port.",
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Profile directory to dump profile traces.",
)
parser.add_argument(
"--profile-name",
type=str,
default=None,
help="The name of this profile run.",
)
parser.add_argument(
"--num-steps",
type=int,
default=5,
help="The number of forward steps to profile.",
)
parser.add_argument(
"--profile-by-stage",
action=argparse.BooleanOptionalAction,
type=bool,
default=False,
help="The number of forward steps to profile.",
)
parser.add_argument(
"--cpu",
action=argparse.BooleanOptionalAction,
type=bool,
default=True,
help="Whether to profile CPU activity",
)
parser.add_argument(
"--gpu",
action=argparse.BooleanOptionalAction,
type=bool,
default=True,
help="Whether to profile GPU activity",
)
parser.add_argument(
"--mem",
action=argparse.BooleanOptionalAction,
type=bool,
default=False,
help="Whether to memory usage (https://pytorch.org/memory_viz)",
)
parser.add_argument(
"--rpd",
action=argparse.BooleanOptionalAction,
type=bool,
default=False,
help="Whether to use rpd profiler (https://github.com/ROCm/rocmProfileData)",
)
args = parser.parse_args()
activities = []
if args.cpu:
activities.append("CPU")
if args.gpu:
activities.append("GPU")
if args.mem:
activities.append("MEM")
if args.rpd:
activities.append("RPD")
run_profile(
args.url,
args.num_steps,
activities,
args.output_dir,
args.profile_name,
args.profile_by_stage,
)
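For reference, the module above can also be driven programmatically instead of via the CLI; a minimal sketch, assuming a server is already listening on localhost:30000 (the URL and profile name are illustrative):

```python
from sglang.profiler import run_profile

# Profile 5 forward steps of a running server and report where the traces are written.
trace_dir = run_profile(
    url="http://localhost:30000",   # illustrative server address
    num_steps=5,
    activities=["CPU", "GPU"],
    output_dir=None,                # defaults to /tmp/sglang-profile
    profile_name="demo-run",        # arbitrary name for this run
    profile_by_stage=False,
)
print(f"Profiling traces written under: {trace_dir}")
```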
@@ -514,9 +514,7 @@ def _set_envs_and_config(server_args: ServerArgs):
             pid, exitcode = os.waitpid(0, os.WNOHANG)
             if exitcode != 0:
                 logger.warning(
-                    "Child process unexpectedly failed with an exit code %d. pid=%d",
-                    exitcode,
-                    pid,
+                    f"Child process unexpectedly failed with {exitcode=}. {pid=}"
                 )

     signal.signal(signal.SIGCHLD, sigchld_handler)
......
@@ -350,6 +350,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
         activities=obj.activities,
         with_stack=obj.with_stack,
         record_shapes=obj.record_shapes,
+        profile_by_stage=obj.profile_by_stage,
     )
     return Response(
         content="Start profiling.\n",
......
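The JSON payload accepted by this endpoint mirrors what `_run_profile` sends above; a minimal sketch of calling it directly (the base URL and output directory are illustrative):

```python
import requests

# Ask a running server to profile the next 3 steps, split into prefill and decode stages.
resp = requests.post(
    "http://localhost:30000/start_profile",
    json={
        "output_dir": "/tmp/sglang-profile",
        "num_steps": "3",
        "activities": ["CPU", "GPU"],
        "profile_by_stage": True,
    },
)
resp.raise_for_status()
print(resp.text)  # "Start profiling.\n"
```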
@@ -401,7 +401,6 @@ def compute_initial_expert_location_metadata(
 ) -> ExpertLocationMetadata:
     data = server_args.init_expert_location
     if data == "trivial":
-        logger.info("init_expert_location from trivial")
         return ExpertLocationMetadata.init_trivial(server_args, model_config)

     # TODO unify with the utils function
......
@@ -848,7 +848,8 @@ class ProfileReqInput:
     # If it is set, profiling is automatically stopped after this step, and
     # the caller doesn't need to run stop_profile.
     num_steps: Optional[int] = None
-    activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
+    activities: Optional[List[str]] = None
+    profile_by_stage: bool = False
     with_stack: Optional[bool] = None
     record_shapes: Optional[bool] = None
@@ -875,6 +876,7 @@ class ProfileReq:
     output_dir: Optional[str] = None
     num_steps: Optional[int] = None
     activities: Optional[List[str]] = None
+    profile_by_stage: bool = False
     with_stack: Optional[bool] = None
     record_shapes: Optional[bool] = None
     profile_id: Optional[str] = None
......
...@@ -34,7 +34,6 @@ import zmq ...@@ -34,7 +34,6 @@ import zmq
from torch.distributed import barrier from torch.distributed import barrier
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt import two_batch_overlap
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.disaggregation.decode import ( from sglang.srt.disaggregation.decode import (
...@@ -63,7 +62,6 @@ from sglang.srt.hf_transformers_utils import ( ...@@ -63,7 +62,6 @@ from sglang.srt.hf_transformers_utils import (
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.expert_distribution import ( from sglang.srt.managers.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder, get_global_expert_distribution_recorder,
) )
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
...@@ -140,6 +138,7 @@ from sglang.srt.utils import ( ...@@ -140,6 +138,7 @@ from sglang.srt.utils import (
broadcast_pyobj, broadcast_pyobj,
configure_logger, configure_logger,
disable_request_logging, disable_request_logging,
get_available_gpu_memory,
get_bool_env_var, get_bool_env_var,
get_zmq_socket, get_zmq_socket,
kill_itself_when_parent_died, kill_itself_when_parent_died,
...@@ -213,7 +212,6 @@ class Scheduler( ...@@ -213,7 +212,6 @@ class Scheduler(
self.gpu_id = gpu_id self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.page_size = server_args.page_size self.page_size = server_args.page_size
# Distributed rank info
self.dp_size = server_args.dp_size self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = ( self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
compute_dp_attention_world_info( compute_dp_attention_world_info(
...@@ -333,12 +331,16 @@ class Scheduler( ...@@ -333,12 +331,16 @@ class Scheduler(
# Print debug info # Print debug info
if tp_rank == 0: if tp_rank == 0:
avail_mem = get_available_gpu_memory(
self.device, self.gpu_id, empty_cache=False
)
logger.info( logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, " f"max_total_num_tokens={self.max_total_num_tokens}, "
f"chunked_prefill_size={server_args.chunked_prefill_size}, " f"chunked_prefill_size={server_args.chunked_prefill_size}, "
f"max_prefill_tokens={self.max_prefill_tokens}, " f"max_prefill_tokens={self.max_prefill_tokens}, "
f"max_running_requests={self.max_running_requests}, " f"max_running_requests={self.max_running_requests}, "
f"context_len={self.model_config.context_len}" f"context_len={self.model_config.context_len}, "
f"available_gpu_mem={avail_mem:.2f} GB"
) )
# Init memory pool and cache # Init memory pool and cache
...@@ -362,6 +364,7 @@ class Scheduler( ...@@ -362,6 +364,7 @@ class Scheduler(
self.current_stream = torch.get_device_module(self.device).current_stream() self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu": if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU self.current_stream.synchronize = lambda: None # No-op for CPU
self.forward_sleep_time = None
# Init session info # Init session info
self.sessions: Dict[str, Session] = {} self.sessions: Dict[str, Session] = {}
...@@ -425,8 +428,14 @@ class Scheduler( ...@@ -425,8 +428,14 @@ class Scheduler(
self.profiler_activities: Optional[List[str]] = None self.profiler_activities: Optional[List[str]] = None
self.profiler_id: Optional[str] = None self.profiler_id: Optional[str] = None
self.profiler_target_forward_ct: Optional[int] = None self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.forward_sleep_time = None self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
# Init metrics stats # Init metrics stats
self.init_metrics() self.init_metrics()
@@ -1518,7 +1527,7 @@ class Scheduler(
         self.new_token_ratio = new_token_ratio

         logger.info(
-            "Decode out of memory happened. "
+            "KV cache pool is full. Retract requests. "
             f"#retracted_reqs: {len(retracted_reqs)}, "
             f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
         )
@@ -1542,13 +1551,8 @@ class Scheduler(
         """Run a batch."""
         self.forward_ct += 1

-        # Check profiler
-        if (
-            self.profiler_target_forward_ct
-            and self.profiler_target_forward_ct <= self.forward_ct
-        ):
-            self.send_to_tokenizer.send_pyobj(self.stop_profile())
+        # Whether to run the profiler
+        self._profile_batch_predicate(batch)

         if self.forward_sleep_time is not None:
             logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
             time.sleep(self.forward_sleep_time)
...@@ -2121,46 +2125,82 @@ class Scheduler( ...@@ -2121,46 +2125,82 @@ class Scheduler(
def profile(self, recv_req: ProfileReq): def profile(self, recv_req: ProfileReq):
if recv_req.type == ProfileReqType.START_PROFILE: if recv_req.type == ProfileReqType.START_PROFILE:
return self.start_profile( if recv_req.profile_by_stage:
recv_req.output_dir, return self.init_profile(
recv_req.num_steps, recv_req.output_dir,
recv_req.activities, recv_req.num_steps,
recv_req.with_stack, recv_req.activities,
recv_req.record_shapes, recv_req.with_stack,
recv_req.profile_id, recv_req.record_shapes,
) recv_req.profile_by_stage,
)
else:
self.init_profile(
recv_req.output_dir,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
)
return self.start_profile(True)
else: else:
return self.stop_profile() return self.stop_profile()
def start_profile( def init_profile(
self, self,
output_dir: Optional[str], output_dir: Optional[str],
num_steps: Optional[int], num_steps: Optional[int],
activities: Optional[List[str]], activities: Optional[List[str]],
with_stack: Optional[bool], with_stack: Optional[bool],
record_shapes: Optional[bool], record_shapes: Optional[bool],
profile_id: Optional[str], profile_by_stage: bool,
) -> None: ) -> ProfileReqOutput:
if self.profiler_activities: if self.profile_in_progress:
return ProfileReqOutput( return ProfileReqOutput(
success=False, success=False,
message="Profiling is already in progress. Call /stop_profile first.", message="Profiling is already in progress. Call /stop_profile first.",
) )
self.profile_by_stage = profile_by_stage
if output_dir is None: if output_dir is None:
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp") output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
if activities is None: if activities is None:
activities = ["CPU", "GPU"] activities = ["CPU", "GPU"]
self.torch_profiler_output_dir = output_dir self.torch_profiler_output_dir = output_dir
self.torch_profiler_with_stack = with_stack
self.torch_profiler_record_shapes = record_shapes
self.profiler_activities = activities self.profiler_activities = activities
self.profiler_id = profile_id
if num_steps:
self.profile_steps = num_steps
if self.profile_by_stage:
self.profiler_target_prefill_ct = num_steps
self.profiler_target_decode_ct = num_steps
self.profiler_prefill_ct = 0
self.profiler_decode_ct = 0
else:
self.profiler_target_forward_ct = self.forward_ct + num_steps
# The caller will be notified when reaching profiler_target_forward_ct
else:
self.profiler_target_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded")
def start_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
stage_str = f" for {stage.__str__()}" if stage else ""
logger.info( logger.info(
"Profiling starts. Traces will be saved to: %s (with id %s)", f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir}",
self.torch_profiler_output_dir,
self.profiler_id,
) )
activities = self.profiler_activities
with_stack = self.torch_profiler_with_stack
record_shapes = self.torch_profiler_record_shapes
activity_map = { activity_map = {
"CPU": torch.profiler.ProfilerActivity.CPU, "CPU": torch.profiler.ProfilerActivity.CPU,
"GPU": torch.profiler.ProfilerActivity.CUDA, "GPU": torch.profiler.ProfilerActivity.CUDA,
...@@ -2169,48 +2209,97 @@ class Scheduler( ...@@ -2169,48 +2209,97 @@ class Scheduler(
activity_map[a] for a in activities if a in activity_map activity_map[a] for a in activities if a in activity_map
] ]
if torchprof_activities: if "RPD" in activities:
from rpdTracerControl import rpdTracerControl
rpdTracerControl.skipCreate()
self.rpd_profile_path = os.path.join(
self.torch_profiler_output_dir,
"rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
)
if self.tp_rank == 0:
import sqlite3
from rocpd.schema import RocpdSchema
if os.path.exists("trace.rpd"):
os.unlink("trace.rpd")
schema = RocpdSchema()
connection = sqlite3.connect("trace.rpd")
schema.writeSchema(connection)
connection.commit()
del connection
torch.distributed.barrier(self.tp_cpu_group)
self.rpd_profiler = rpdTracerControl()
self.rpd_profiler.setPythonTrace(True)
self.rpd_profiler.start()
self.rpd_profiler.rangePush("", "rpd profile range", "")
self.profile_in_progress = True
elif torchprof_activities:
self.torch_profiler = torch.profiler.profile( self.torch_profiler = torch.profiler.profile(
activities=torchprof_activities, activities=torchprof_activities,
with_stack=with_stack if with_stack is not None else True, with_stack=with_stack if with_stack is not None else True,
record_shapes=record_shapes if record_shapes is not None else False, record_shapes=record_shapes if record_shapes is not None else False,
) )
self.torch_profiler.start() self.torch_profiler.start()
self.profile_in_progress = True
if "MEM" in activities: if "MEM" in activities:
torch.cuda.memory._record_memory_history(max_entries=100000) torch.cuda.memory._record_memory_history(max_entries=100000)
self.profile_in_progress = True
if "CUDA_PROFILER" in activities: if "CUDA_PROFILER" in activities:
torch.cuda.cudart().cudaProfilerStart() torch.cuda.cudart().cudaProfilerStart()
if num_steps: return ProfileReqOutput(success=True, message="Succeeded")
self.profiler_target_forward_ct = self.forward_ct + num_steps
# The caller will be notified when reaching profiler_target_forward_ct
else:
self.profiler_target_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded")
def stop_profile(self) -> None: def stop_profile(
if self.profiler_activities is None: self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
if not self.profile_in_progress:
return ProfileReqOutput( return ProfileReqOutput(
success=False, success=False,
message="Profiling is not in progress. Call /start_profile first.", message="Profiling is not in progress. Call /start_profile first.",
) )
logger.info("Stop profiling...") stage_suffix = f"-{stage.__str__()}" if stage else ""
logger.info("Stop profiling" + stage_suffix + "...")
if self.torch_profiler is not None: if self.torch_profiler is not None:
self.torch_profiler.stop() self.torch_profiler.stop()
self.torch_profiler.export_chrome_trace( self.torch_profiler.export_chrome_trace(
os.path.join( os.path.join(
self.torch_profiler_output_dir, self.torch_profiler_output_dir,
self.profiler_id + f"-TP-{self.tp_rank}" + ".trace.json.gz", str(time.time())
+ f"-TP-{self.tp_rank}"
+ stage_suffix
+ ".trace.json.gz",
) )
) )
torch.distributed.barrier(self.tp_cpu_group)
if self.rpd_profiler is not None:
self.rpd_profiler.rangePop()
self.rpd_profiler.stop()
self.rpd_profiler.flush()
if "MEM" in self.profiler_activities: torch.distributed.barrier(self.tp_cpu_group)
if self.tp_rank == 0:
from sglang.srt.utils import rpd_to_chrome_trace
rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
self.rpd_profiler = None
self.rpd_profiler_path = None
if self.profiler_activities is not None and "MEM" in self.profiler_activities:
memory_profile_path = os.path.join( memory_profile_path = os.path.join(
self.torch_profiler_output_dir, self.torch_profiler_output_dir,
self.profiler_id + f"-TP-{self.tp_rank}-memory" + ".pickle", str(time.time())
+ f"-TP-{self.tp_rank}-memory"
+ stage_suffix
+ ".pickle",
) )
torch.cuda.memory._dump_snapshot(memory_profile_path) torch.cuda.memory._dump_snapshot(memory_profile_path)
torch.cuda.memory._record_memory_history(enabled=None) torch.cuda.memory._record_memory_history(enabled=None)
...@@ -2223,11 +2312,38 @@ class Scheduler( ...@@ -2223,11 +2312,38 @@ class Scheduler(
self.torch_profiler_output_dir, self.torch_profiler_output_dir,
) )
self.torch_profiler = None self.torch_profiler = None
self.torch_profiler_output_dir = None self.profile_in_progress = False
self.profiler_activities = None
self.profiler_target_forward_ct = None return ProfileReqOutput(success=True, message="Succeeded.")
return ProfileReqOutput(success=True, message="Succeeded") def _profile_batch_predicate(self, batch):
if self.profile_by_stage:
if batch.forward_mode.is_prefill():
if self.profiler_prefill_ct == 0:
self.start_profile(batch.forward_mode)
self.profiler_prefill_ct += 1
if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.EXTEND)
elif batch.forward_mode.is_decode():
if self.profiler_decode_ct == 0:
if self.profile_in_progress:
# force trace flush
self.stop_profile(ForwardMode.EXTEND)
self.start_profile(batch.forward_mode)
self.profiler_decode_ct += 1
if self.profiler_decode_ct > self.profiler_target_decode_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.DECODE)
else:
raise RuntimeError("unsupported profile stage")
else:
# Check profiler
if (
self.profiler_target_forward_ct
and self.profiler_target_forward_ct <= self.forward_ct
):
self.stop_profile()
def expert_distribution_handle(self, recv_req: ExpertDistributionReq): def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
if recv_req == ExpertDistributionReq.START_RECORD: if recv_req == ExpertDistributionReq.START_RECORD:
......
@@ -796,6 +796,7 @@ class TokenizerManager:
         activities: Optional[List[str]] = None,
         with_stack: Optional[bool] = None,
         record_shapes: Optional[bool] = None,
+        profile_by_stage: bool = False,
     ):
         self.auto_create_handle_loop()
         req = ProfileReq(
@@ -805,6 +806,7 @@ class TokenizerManager:
             activities=activities,
             with_stack=with_stack,
             record_shapes=record_shapes,
+            profile_by_stage=profile_by_stage,
             profile_id=str(time.time()),
         )
         return await self._execute_profile(req)
......
@@ -39,10 +39,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     PPProxyTensors,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
-from sglang.srt.two_batch_overlap import (
-    TboCudaGraphRunnerPlugin,
-    TboForwardBatchPreparer,
-)
+from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
......
@@ -77,11 +77,7 @@ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
-from sglang.srt.model_loader.loader import (
-    DefaultModelLoader,
-    device_loading_context,
-    get_model_loader,
-)
+from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
......
@@ -1643,7 +1643,7 @@ def auto_choose_speculative_params(arch: str):
         return (5, 4, 8)
     elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
         # The default value for deepseek
-        return (5, 4, 8)
+        return (3, 1, 4)
     elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
         return (5, 4, 8)
     else:
......
@@ -93,6 +93,11 @@ def is_in_ci():
     return get_bool_env_var("SGLANG_IS_IN_CI")


+def is_in_amd_ci():
+    """Return whether it is in an AMD CI runner."""
+    return get_bool_env_var("SGLANG_AMD_CI")
+
+
 if is_in_ci():
     DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
         5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
......
...@@ -16,7 +16,8 @@ suites = { ...@@ -16,7 +16,8 @@ suites = {
TestFile("models/lora/test_lora.py", 76), TestFile("models/lora/test_lora.py", 76),
TestFile("models/lora/test_lora_backend.py", 99), TestFile("models/lora/test_lora_backend.py", 99),
TestFile("models/lora/test_multi_lora_backend.py", 60), TestFile("models/lora/test_multi_lora_backend.py", 60),
TestFile("models/test_embedding_models.py", 184), TestFile("models/lora/test_lora_cuda_graph.py", 250),
TestFile("models/test_embedding_models.py", 73),
# TestFile("models/test_clip_models.py", 52), # TestFile("models/test_clip_models.py", 52),
TestFile("models/test_compressed_tensors_models.py", 42), TestFile("models/test_compressed_tensors_models.py", 42),
TestFile("models/test_generation_models.py", 103), TestFile("models/test_generation_models.py", 103),
...@@ -24,44 +25,43 @@ suites = { ...@@ -24,44 +25,43 @@ suites = {
# TestFile("models/test_grok_models.py", 60), # Disabled due to illegal memory access # TestFile("models/test_grok_models.py", 60), # Disabled due to illegal memory access
TestFile("models/test_qwen_models.py", 82), TestFile("models/test_qwen_models.py", 82),
TestFile("models/test_reward_models.py", 132), TestFile("models/test_reward_models.py", 132),
TestFile("models/test_vlm_models.py", 317), TestFile("models/test_vlm_models.py", 437),
TestFile("test_abort.py", 51), TestFile("test_abort.py", 51),
TestFile("test_block_int8.py", 22), TestFile("test_block_int8.py", 22),
TestFile("test_create_kvindices.py", 2), TestFile("test_create_kvindices.py", 2),
TestFile("test_chunked_prefill.py", 285), TestFile("test_chunked_prefill.py", 313),
TestFile("test_eagle_infer.py", 584), TestFile("test_eagle_infer.py", 619),
TestFile("test_ebnf_constrained.py", 108), TestFile("test_ebnf_constrained.py", 108),
TestFile("test_enable_thinking.py", 70),
TestFile("test_embedding_openai_server.py", 141), TestFile("test_embedding_openai_server.py", 141),
TestFile("test_eval_fp8_accuracy.py", 303), TestFile("test_eval_fp8_accuracy.py", 303),
TestFile("test_fa3.py", 376), TestFile("test_fa3.py", 376),
TestFile("test_fim_completion.py", 40), TestFile("test_flashmla.py", 352),
TestFile("test_fp8_kernel.py", 8), TestFile("test_fp8_kernel.py", 8),
TestFile("test_function_call_parser.py", 10), TestFile("test_function_call_parser.py", 10),
TestFile("test_fused_moe.py", 30), TestFile("test_fused_moe.py", 30),
TestFile("test_hicache.py", 116), TestFile("test_hicache.py", 116),
TestFile("test_hicache_mla.py", 254), TestFile("test_hicache_mla.py", 127),
TestFile("test_hidden_states.py", 55), TestFile("test_hidden_states.py", 55),
TestFile("test_int8_kernel.py", 8), TestFile("test_int8_kernel.py", 8),
TestFile("test_input_embeddings.py", 38), TestFile("test_input_embeddings.py", 38),
TestFile("test_json_constrained.py", 98), TestFile("test_json_constrained.py", 98),
TestFile("test_large_max_new_tokens.py", 41), TestFile("test_large_max_new_tokens.py", 41),
TestFile("test_metrics.py", 32), TestFile("test_metrics.py", 32),
TestFile("test_mla.py", 242), TestFile("test_mla.py", 167),
TestFile("test_mla_deepseek_v3.py", 221), TestFile("test_mla_deepseek_v3.py", 342),
TestFile("test_mla_int8_deepseek_v3.py", 389), TestFile("test_mla_int8_deepseek_v3.py", 429),
TestFile("test_mla_flashinfer.py", 395), TestFile("test_mla_flashinfer.py", 302),
TestFile("test_mla_fp8.py", 153), TestFile("test_mla_fp8.py", 93),
TestFile("test_flashmla.py", 300),
TestFile("test_no_chunked_prefill.py", 108), TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 216), TestFile("test_no_overlap_scheduler.py", 234),
TestFile("test_openai_function_calling.py", 60), TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149), TestFile("test_openai_server.py", 149),
TestFile("test_penalty.py", 41), TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60), TestFile("test_page_size.py", 60),
TestFile("test_pytorch_sampling_backend.py", 66), TestFile("test_pytorch_sampling_backend.py", 66),
TestFile("test_radix_attention.py", 167), TestFile("test_radix_attention.py", 105),
TestFile("test_reasoning_content.py", 89), TestFile("test_reasoning_content.py", 89),
TestFile("test_enable_thinking.py", 70),
TestFile("test_regex_constrained.py", 64), TestFile("test_regex_constrained.py", 64),
TestFile("test_release_memory_occupation.py", 44), TestFile("test_release_memory_occupation.py", 44),
TestFile("test_request_length_validation.py", 31), TestFile("test_request_length_validation.py", 31),
...@@ -70,13 +70,13 @@ suites = { ...@@ -70,13 +70,13 @@ suites = {
TestFile("test_skip_tokenizer_init.py", 117), TestFile("test_skip_tokenizer_init.py", 117),
TestFile("test_srt_engine.py", 261), TestFile("test_srt_engine.py", 261),
TestFile("test_srt_endpoint.py", 130), TestFile("test_srt_endpoint.py", 130),
TestFile("test_tool_choice.py", 120), TestFile("test_tool_choice.py", 226),
TestFile("test_torch_compile.py", 76), TestFile("test_torch_compile.py", 76),
TestFile("test_torch_compile_moe.py", 172), TestFile("test_torch_compile_moe.py", 172),
TestFile("test_torch_native_attention_backend.py", 123), TestFile("test_torch_native_attention_backend.py", 123),
TestFile("test_torchao.py", 70), TestFile("test_torchao.py", 70),
TestFile("test_triton_attention_kernels.py", 4), TestFile("test_triton_attention_kernels.py", 4),
TestFile("test_triton_attention_backend.py", 134), TestFile("test_triton_attention_backend.py", 150),
TestFile("test_triton_moe_channel_fp8_kernel.py", 25), TestFile("test_triton_moe_channel_fp8_kernel.py", 25),
TestFile("test_triton_sliding_window.py", 250), TestFile("test_triton_sliding_window.py", 250),
TestFile("test_update_weights_from_disk.py", 114), TestFile("test_update_weights_from_disk.py", 114),
...@@ -84,10 +84,9 @@ suites = { ...@@ -84,10 +84,9 @@ suites = {
TestFile("test_vertex_endpoint.py", 31), TestFile("test_vertex_endpoint.py", 31),
TestFile("test_vision_chunked_prefill.py", 175), TestFile("test_vision_chunked_prefill.py", 175),
TestFile("test_vlm_input_format.py", 300), TestFile("test_vlm_input_format.py", 300),
TestFile("test_vision_openai_server_a.py", 700), TestFile("test_vision_openai_server_a.py", 584),
TestFile("test_vision_openai_server_b.py", 700), TestFile("test_vision_openai_server_b.py", 556),
TestFile("test_w8a8_quantization.py", 46), TestFile("test_w8a8_quantization.py", 46),
TestFile("models/lora/test_lora_cuda_graph.py", 250),
], ],
"per-commit-amd": [ "per-commit-amd": [
TestFile("test_mla.py", 242), TestFile("test_mla.py", 242),
...@@ -119,9 +118,9 @@ suites = { ...@@ -119,9 +118,9 @@ suites = {
# TestFile("test_deepep_intranode.py", 50), # TestFile("test_deepep_intranode.py", 50),
# TestFile("test_deepep_low_latency.py", 50), # TestFile("test_deepep_low_latency.py", 50),
# TestFile("test_moe_deepep_eval_accuracy_large.py", 250), # TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
TestFile("test_disaggregation.py", 210), TestFile("test_disaggregation.py", 270),
TestFile("test_disaggregation_different_tp.py", 210), TestFile("test_disaggregation_different_tp.py", 155),
TestFile("test_full_deepseek_v3.py", 250), TestFile("test_full_deepseek_v3.py", 463),
], ],
"per-commit-8-gpu-amd": [ "per-commit-8-gpu-amd": [
TestFile("test_full_deepseek_v3.py", 250), TestFile("test_full_deepseek_v3.py", 250),
...@@ -133,11 +132,11 @@ suites = { ...@@ -133,11 +132,11 @@ suites = {
TestFile("test_nightly_gsm8k_eval_amd.py"), TestFile("test_nightly_gsm8k_eval_amd.py"),
], ],
"vllm_dependency_test": [ "vllm_dependency_test": [
TestFile("test_vllm_dependency.py"),
TestFile("test_awq.py"), TestFile("test_awq.py"),
TestFile("test_bnb.py"),
TestFile("test_gguf.py", 78), TestFile("test_gguf.py", 78),
TestFile("test_gptqmodel_dynamic.py", 72), TestFile("test_gptqmodel_dynamic.py", 72),
TestFile("test_bnb.py"), TestFile("test_vllm_dependency.py"),
], ],
} }
......
...@@ -6,6 +6,7 @@ from sglang.test.test_utils import ( ...@@ -6,6 +6,7 @@ from sglang.test.test_utils import (
DEFAULT_MOE_MODEL_NAME_FOR_TEST, DEFAULT_MOE_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
CustomTestCase, CustomTestCase,
is_in_amd_ci,
is_in_ci, is_in_ci,
run_bench_offline_throughput, run_bench_offline_throughput,
run_bench_one_batch, run_bench_one_batch,
...@@ -46,7 +47,7 @@ class TestBenchOneBatch(CustomTestCase): ...@@ -46,7 +47,7 @@ class TestBenchOneBatch(CustomTestCase):
f"### test_moe_tp2_bs1 (Mixtral-8x7B)\n" f"### test_moe_tp2_bs1 (Mixtral-8x7B)\n"
f"output_throughput: {output_throughput:.2f} token/s\n" f"output_throughput: {output_throughput:.2f} token/s\n"
) )
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertGreater(output_throughput, 85) self.assertGreater(output_throughput, 85)
else: else:
self.assertGreater(output_throughput, 125) self.assertGreater(output_throughput, 125)
...@@ -62,7 +63,7 @@ class TestBenchOneBatch(CustomTestCase): ...@@ -62,7 +63,7 @@ class TestBenchOneBatch(CustomTestCase):
f"### test_torch_compile_tp2_bs1 (Mixtral-8x7B)\n" f"### test_torch_compile_tp2_bs1 (Mixtral-8x7B)\n"
f"output_throughput: {output_throughput:.2f} token/s\n" f"output_throughput: {output_throughput:.2f} token/s\n"
) )
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertGreater(output_throughput, 200) self.assertGreater(output_throughput, 200)
else: else:
self.assertGreater(output_throughput, 220) self.assertGreater(output_throughput, 220)
......
import os
import unittest import unittest
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -8,8 +7,8 @@ from sglang.test.test_utils import ( ...@@ -8,8 +7,8 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_FP8, DEFAULT_MODEL_NAME_FOR_TEST_FP8,
DEFAULT_MOE_MODEL_NAME_FOR_TEST, DEFAULT_MOE_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST,
DEFAULT_VLM_CHAT_TEMPLATE_FOR_TEST,
CustomTestCase, CustomTestCase,
is_in_amd_ci,
is_in_ci, is_in_ci,
run_bench_serving, run_bench_serving,
write_github_step_summary, write_github_step_summary,
...@@ -31,7 +30,7 @@ class TestBenchServing(CustomTestCase): ...@@ -31,7 +30,7 @@ class TestBenchServing(CustomTestCase):
f"### test_offline_throughput_default\n" f"### test_offline_throughput_default\n"
f'Output throughput: {res["output_throughput"]:.2f} token/s\n' f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
) )
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertGreater(res["output_throughput"], 3150) self.assertGreater(res["output_throughput"], 3150)
else: else:
self.assertGreater(res["output_throughput"], 3800) self.assertGreater(res["output_throughput"], 3800)
...@@ -69,7 +68,7 @@ class TestBenchServing(CustomTestCase): ...@@ -69,7 +68,7 @@ class TestBenchServing(CustomTestCase):
f"### test_offline_throughput_without_radix_cache\n" f"### test_offline_throughput_without_radix_cache\n"
f'Output throughput: {res["output_throughput"]:.2f} token/s\n' f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
) )
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertGreater(res["output_throughput"], 3050) self.assertGreater(res["output_throughput"], 3050)
else: else:
self.assertGreater(res["output_throughput"], 3800) self.assertGreater(res["output_throughput"], 3800)
...@@ -107,7 +106,7 @@ class TestBenchServing(CustomTestCase): ...@@ -107,7 +106,7 @@ class TestBenchServing(CustomTestCase):
f"### test_offline_throughput_with_triton_attention_backend\n" f"### test_offline_throughput_with_triton_attention_backend\n"
f'Output throughput: {res["output_throughput"]:.2f} token/s\n' f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
) )
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertGreater(res["output_throughput"], 3500) self.assertGreater(res["output_throughput"], 3500)
else: else:
self.assertGreater(res["output_throughput"], 3700) self.assertGreater(res["output_throughput"], 3700)
...@@ -125,7 +124,7 @@ class TestBenchServing(CustomTestCase): ...@@ -125,7 +124,7 @@ class TestBenchServing(CustomTestCase):
f"### test_offline_throughput_default_fp8\n" f"### test_offline_throughput_default_fp8\n"
f'Output throughput: {res["output_throughput"]:.2f} token/s\n' f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
) )
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertGreater(res["output_throughput"], 3500) self.assertGreater(res["output_throughput"], 3500)
else: else:
self.assertGreater(res["output_throughput"], 4300) self.assertGreater(res["output_throughput"], 4300)
...@@ -144,7 +143,7 @@ class TestBenchServing(CustomTestCase): ...@@ -144,7 +143,7 @@ class TestBenchServing(CustomTestCase):
f'median_e2e_latency_ms: {res["median_e2e_latency_ms"]:.2f} ms\n' f'median_e2e_latency_ms: {res["median_e2e_latency_ms"]:.2f} ms\n'
) )
self.assertLess(res["median_e2e_latency_ms"], 11000) self.assertLess(res["median_e2e_latency_ms"], 11000)
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertLess(res["median_ttft_ms"], 115) self.assertLess(res["median_ttft_ms"], 115)
else: else:
self.assertLess(res["median_ttft_ms"], 86) self.assertLess(res["median_ttft_ms"], 86)
...@@ -167,7 +166,7 @@ class TestBenchServing(CustomTestCase): ...@@ -167,7 +166,7 @@ class TestBenchServing(CustomTestCase):
f"### test_vlm_offline_throughput\n" f"### test_vlm_offline_throughput\n"
f'Output throughput: {res["output_throughput"]:.2f} token/s\n' f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
) )
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertGreater(res["output_throughput"], 2000) self.assertGreater(res["output_throughput"], 2000)
# TODO: not set yet, need AMD machine # TODO: not set yet, need AMD machine
else: else:
...@@ -191,7 +190,7 @@ class TestBenchServing(CustomTestCase): ...@@ -191,7 +190,7 @@ class TestBenchServing(CustomTestCase):
f'median_e2e_latency_ms: {res["median_e2e_latency_ms"]:.2f} ms\n' f'median_e2e_latency_ms: {res["median_e2e_latency_ms"]:.2f} ms\n'
) )
self.assertLess(res["median_e2e_latency_ms"], 16500) self.assertLess(res["median_e2e_latency_ms"], 16500)
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertLess(res["median_ttft_ms"], 150) self.assertLess(res["median_ttft_ms"], 150)
# TODO: not set yet, need AMD machine # TODO: not set yet, need AMD machine
else: else:
...@@ -230,7 +229,7 @@ class TestBenchServing(CustomTestCase): ...@@ -230,7 +229,7 @@ class TestBenchServing(CustomTestCase):
f'median_e2e_latency_ms: {res["median_e2e_latency_ms"]:.2f} ms\n' f'median_e2e_latency_ms: {res["median_e2e_latency_ms"]:.2f} ms\n'
f'accept_length: {res["accept_length"]:.2f} \n' f'accept_length: {res["accept_length"]:.2f} \n'
) )
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertLess(res["median_e2e_latency_ms"], 1800) self.assertLess(res["median_e2e_latency_ms"], 1800)
else: else:
self.assertLess(res["median_e2e_latency_ms"], 900) self.assertLess(res["median_e2e_latency_ms"], 900)
...@@ -249,7 +248,7 @@ class TestBenchServing(CustomTestCase): ...@@ -249,7 +248,7 @@ class TestBenchServing(CustomTestCase):
f"### test_moe_offline_throughput_default\n" f"### test_moe_offline_throughput_default\n"
f'Output throughput: {res["output_throughput"]:.2f} token/s\n' f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
) )
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertGreater(res["output_throughput"], 2100) self.assertGreater(res["output_throughput"], 2100)
else: else:
self.assertGreater(res["output_throughput"], 2200) self.assertGreater(res["output_throughput"], 2200)
...@@ -267,7 +266,7 @@ class TestBenchServing(CustomTestCase): ...@@ -267,7 +266,7 @@ class TestBenchServing(CustomTestCase):
f"### test_moe_offline_throughput_without_radix_cache\n" f"### test_moe_offline_throughput_without_radix_cache\n"
f'Output throughput: {res["output_throughput"]:.2f} token/s\n' f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
) )
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertGreater(res["output_throughput"], 2100) self.assertGreater(res["output_throughput"], 2100)
else: else:
self.assertGreater(res["output_throughput"], 2200) self.assertGreater(res["output_throughput"], 2200)
...@@ -289,7 +288,7 @@ class TestBenchServing(CustomTestCase): ...@@ -289,7 +288,7 @@ class TestBenchServing(CustomTestCase):
f"### test_pp_offline_throughput_default_decode\n" f"### test_pp_offline_throughput_default_decode\n"
f'Output throughput: {res["output_throughput"]:.2f} token/s\n' f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
) )
self.assertGreater(res["output_throughput"], 7500) self.assertGreater(res["output_throughput"], 6700)
def test_pp_long_context_prefill(self): def test_pp_long_context_prefill(self):
res = run_bench_serving( res = run_bench_serving(
......
import os
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
...@@ -11,6 +10,7 @@ from sglang.test.test_utils import ( ...@@ -11,6 +10,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase, CustomTestCase,
is_in_amd_ci,
is_in_ci, is_in_ci,
popen_launch_server, popen_launch_server,
write_github_step_summary, write_github_step_summary,
...@@ -67,7 +67,7 @@ class TestDeepseekV3(CustomTestCase): ...@@ -67,7 +67,7 @@ class TestDeepseekV3(CustomTestCase):
write_github_step_summary( write_github_step_summary(
f"### test_bs_1_speed (deepseek-v3)\n" f"{speed=:.2f} token/s\n" f"### test_bs_1_speed (deepseek-v3)\n" f"{speed=:.2f} token/s\n"
) )
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertGreater(speed, 12) self.assertGreater(speed, 12)
else: else:
self.assertGreater(speed, 75) self.assertGreater(speed, 75)
...@@ -91,7 +91,7 @@ class TestDeepseekV3MTP(CustomTestCase): ...@@ -91,7 +91,7 @@ class TestDeepseekV3MTP(CustomTestCase):
"--speculative-num-draft-tokens", "--speculative-num-draft-tokens",
"4", "4",
] ]
if os.environ.get("SGLANG_AMD_CI") != "1": if not is_in_amd_ci():
other_args += ["--mem-frac", "0.7"] other_args += ["--mem-frac", "0.7"]
cls.process = popen_launch_server( cls.process = popen_launch_server(
cls.model, cls.model,
...@@ -148,11 +148,11 @@ class TestDeepseekV3MTP(CustomTestCase): ...@@ -148,11 +148,11 @@ class TestDeepseekV3MTP(CustomTestCase):
f"{acc_length=:.2f}\n" f"{acc_length=:.2f}\n"
f"{speed=:.2f} token/s\n" f"{speed=:.2f} token/s\n"
) )
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertGreater(acc_length, 2.8) self.assertGreater(acc_length, 2.8)
else: else:
self.assertGreater(acc_length, 2.9) self.assertGreater(acc_length, 2.9)
if os.getenv("SGLANG_AMD_CI") == "1": if is_in_amd_ci():
self.assertGreater(speed, 15) self.assertGreater(speed, 15)
else: else:
self.assertGreater(speed, 105) self.assertGreater(speed, 105)
......
...@@ -24,8 +24,8 @@ class TestMLA(CustomTestCase): ...@@ -24,8 +24,8 @@ class TestMLA(CustomTestCase):
other_args=[ other_args=[
"--trust-remote-code", "--trust-remote-code",
"--enable-torch-compile", "--enable-torch-compile",
"--cuda-graph-max-bs", "--torch-compile-max-bs",
"2", "4",
"--chunked-prefill-size", "--chunked-prefill-size",
"256", "256",
], ],
...@@ -35,18 +35,6 @@ class TestMLA(CustomTestCase): ...@@ -35,18 +35,6 @@ class TestMLA(CustomTestCase):
def tearDownClass(cls): def tearDownClass(cls):
kill_process_tree(cls.process.pid) kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.5)
def test_mgsm_en(self): def test_mgsm_en(self):
args = SimpleNamespace( args = SimpleNamespace(
base_url=self.base_url, base_url=self.base_url,
......
...@@ -57,50 +57,6 @@ class TestFlashinferMLA(CustomTestCase): ...@@ -57,50 +57,6 @@ class TestFlashinferMLA(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.62) self.assertGreater(metrics["accuracy"], 0.62)
class TestFlashinferMLANoRagged(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmsys/sglang-ci-dsv3-test"
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = ["--trust-remote-code"]
if torch.cuda.is_available() and torch.version.cuda:
other_args.extend(
[
"--enable-torch-compile",
"--disable-cuda-graph",
"--cuda-graph-max-bs",
"4",
"--attention-backend",
"flashinfer",
]
)
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.62)
class TestFlashinferMLAMTP(CustomTestCase): class TestFlashinferMLAMTP(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
......
...@@ -17,6 +17,7 @@ from sglang.test.test_utils import ( ...@@ -17,6 +17,7 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
is_in_ci,
popen_launch_server, popen_launch_server,
run_bench_one_batch_server, run_bench_one_batch_server,
) )
...@@ -59,7 +60,7 @@ class TestPPAccuracy(unittest.TestCase): ...@@ -59,7 +60,7 @@ class TestPPAccuracy(unittest.TestCase):
self.assertGreater(metrics["accuracy"], 0.74) self.assertGreater(metrics["accuracy"], 0.74)
# Wait a little bit so that the memory check happens. # Wait a little bit so that the memory check happens.
time.sleep(5) time.sleep(4)
class TestQwenPPAccuracy(unittest.TestCase): class TestQwenPPAccuracy(unittest.TestCase):
...@@ -97,20 +98,17 @@ class TestQwenPPAccuracy(unittest.TestCase): ...@@ -97,20 +98,17 @@ class TestQwenPPAccuracy(unittest.TestCase):
finally: finally:
kill_process_tree(process.pid) kill_process_tree(process.pid)
def test_baseline_accuracy(self): @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
metrics = self.run_gsm8k_test(pp_size=1)
print(f"[Qwen Baseline] {metrics=}")
self.assertGreater(metrics["accuracy"], 0.74)
def test_pp_consistency(self): def test_pp_consistency(self):
baseline = self.run_gsm8k_test(pp_size=1) baseline = self.run_gsm8k_test(pp_size=1)
pp_metrics = self.run_gsm8k_test(pp_size=2) pp_metrics = self.run_gsm8k_test(pp_size=2)
print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}") print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
self.assertGreaterEqual(baseline["accuracy"], 0.74)
self.assertGreaterEqual( self.assertGreaterEqual(
pp_metrics["accuracy"], pp_metrics["accuracy"],
baseline["accuracy"] - 0.01, baseline["accuracy"] - 0.02,
msg=( msg=(
f"PP accuracy dropped more than 1% compared to baseline. " f"PP accuracy dropped more than 1% compared to baseline. "
f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}" f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}"
...@@ -155,20 +153,16 @@ class TestQwenPPTieWeightsAccuracy(unittest.TestCase): ...@@ -155,20 +153,16 @@ class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
finally: finally:
kill_process_tree(process.pid) kill_process_tree(process.pid)
def test_baseline_accuracy(self):
metrics = self.run_gsm8k_test(pp_size=1)
print(f"[Qwen Baseline] {metrics=}")
self.assertGreater(metrics["accuracy"], 0.39)
def test_pp_consistency(self): def test_pp_consistency(self):
baseline = self.run_gsm8k_test(pp_size=1) baseline = self.run_gsm8k_test(pp_size=1)
pp_metrics = self.run_gsm8k_test(pp_size=2) pp_metrics = self.run_gsm8k_test(pp_size=2)
print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}") print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
self.assertGreaterEqual(baseline["accuracy"], 0.38)
self.assertGreaterEqual( self.assertGreaterEqual(
pp_metrics["accuracy"], pp_metrics["accuracy"],
baseline["accuracy"] - 0.01, baseline["accuracy"] - 0.02,
msg=( msg=(
f"PP accuracy dropped more than 1% compared to baseline. " f"PP accuracy dropped more than 1% compared to baseline. "
f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}" f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}"
...@@ -211,20 +205,16 @@ class TestQwenMoePPAccuracy(unittest.TestCase): ...@@ -211,20 +205,16 @@ class TestQwenMoePPAccuracy(unittest.TestCase):
finally: finally:
kill_process_tree(process.pid) kill_process_tree(process.pid)
def test_baseline_accuracy(self):
metrics = self.run_gsm8k_test(pp_size=1)
print(f"[Qwen Baseline] {metrics=}")
self.assertGreater(metrics["accuracy"], 0.74)
def test_pp_consistency(self): def test_pp_consistency(self):
baseline = self.run_gsm8k_test(pp_size=1) baseline = self.run_gsm8k_test(pp_size=1)
pp_metrics = self.run_gsm8k_test(pp_size=2) pp_metrics = self.run_gsm8k_test(pp_size=2)
print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}") print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
self.assertGreaterEqual(baseline["accuracy"], 0.74)
self.assertGreaterEqual( self.assertGreaterEqual(
pp_metrics["accuracy"], pp_metrics["accuracy"],
baseline["accuracy"] - 0.01, baseline["accuracy"] - 0.02,
msg=( msg=(
f"PP accuracy dropped more than 1% compared to baseline. " f"PP accuracy dropped more than 1% compared to baseline. "
f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}" f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}"
......