Commit 04b262cd authored by Ying Sheng, committed by GitHub

[Fix] Fix major performance bug in certain cases (#1563)


Co-authored-by: hnyls2002 <hnyls2002@gmail.com>
parent 2432ad40
@@ -130,6 +130,12 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default
+
+      - name: Benchmark Offline Throughput (Non-streaming, small batch size)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
 
   performance-test-1-gpu-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner
......
@@ -845,6 +845,7 @@ def run_benchmark(args_: argparse.Namespace):
     tokenizer = get_tokenizer(tokenizer_id)
 
     if args.dataset_name == "sharegpt":
+        assert args.random_input_len is None and args.random_output_len is None
         input_requests = sample_sharegpt_requests(
             dataset_path=args.dataset_path,
             num_requests=args.num_prompts,
@@ -852,6 +853,7 @@ def run_benchmark(args_: argparse.Namespace):
             fixed_output_len=args.sharegpt_output_len,
         )
     elif args.dataset_name == "random":
+        assert args.random_input_len is not None and args.random_output_len is not None
         input_requests = sample_random_requests(
             input_len=args.random_input_len,
             output_len=args.random_output_len,
@@ -964,13 +966,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--random-input-len",
         type=int,
-        default=1024,
         help="Number of input tokens per request, used only for random dataset.",
     )
     parser.add_argument(
         "--random-output-len",
         type=int,
-        default=128,
         help="Number of output tokens per request, used only for random dataset.",
     )
     parser.add_argument(
......
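The two asserts above replace the old hard-coded defaults (1024/128) for the random-dataset length flags, so callers must now pass the lengths explicitly for the random dataset and must not pass them for sharegpt. A minimal standalone sketch of that argument contract (this is an assumed illustration, not the real bench_serving argument parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dataset-name", choices=["sharegpt", "random"], default="sharegpt")
# No defaults any more: the values stay None unless the user supplies them.
parser.add_argument("--random-input-len", type=int)
parser.add_argument("--random-output-len", type=int)
args = parser.parse_args()

if args.dataset_name == "random":
    # Random prompts need explicit lengths now that the defaults are gone.
    assert args.random_input_len is not None and args.random_output_len is not None
else:
    # For sharegpt the lengths come from the dataset, so the flags must be unset.
    assert args.random_input_len is None and args.random_output_len is None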
@@ -222,7 +222,7 @@ class Scheduler:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.do_not_get_new_batch = False
+        self.batch_is_full = False
 
     def event_loop(self):
         while True:
@@ -261,12 +261,10 @@ class Scheduler:
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
-                self.do_not_get_new_batch = False
             elif isinstance(
                 recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
             ):
                 self.handle_embedding_request(recv_req)
-                self.do_not_get_new_batch = False
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
@@ -279,11 +277,12 @@ class Scheduler:
 
     @torch.inference_mode()
     def forward_step(self):
-        if self.do_not_get_new_batch and self.current_inflight_req is None:
+        if (
+            self.batch_is_full or len(self.waiting_queue) == 0
+        ) and self.current_inflight_req is None:
             new_batch = None
         else:
             new_batch = self.get_new_prefill_batch()
-        self.do_not_get_new_batch = False
 
         if new_batch is not None:
             # Run a new prefill batch
@@ -447,6 +446,7 @@ class Scheduler:
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
+            self.batch_is_full = True
             return None
 
         # Get priority queue
@@ -490,9 +490,11 @@ class Scheduler:
                 )
                 > self.max_loras_per_batch
             ):
+                self.batch_is_full = True
                 break
 
             if adder.no_remaining_tokens():
+                self.batch_is_full = True
                 break
             req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
@@ -500,6 +502,7 @@ class Scheduler:
                 not res
                 or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
+                self.batch_is_full = True
                 break
 
         can_run_list = adder.can_run_list
@@ -810,9 +813,6 @@ class Scheduler:
                 if req.top_logprobs_num > 0:
                     req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
-        if not has_finished:
-            self.do_not_get_new_batch = True
-
         self.handle_finished_requests(batch)
 
     def handle_finished_requests(self, batch: ScheduleBatch):
@@ -833,6 +833,8 @@ class Scheduler:
         for i, req in enumerate(batch.reqs):
             if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)
+            else:
+                self.batch_is_full = False
 
             if req.finished() or (
                 req.stream
......
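A self-contained toy sketch of the flag semantics in the scheduler hunks above (ToyScheduler and Req are invented names for illustration; this is not the real sglang Scheduler). The old do_not_get_new_batch flag suppressed prefill whenever the previous decode step finished nothing, which could leave a small running batch decoding alone while requests sat in the waiting queue. The new batch_is_full flag is set only when the batch genuinely cannot accept more requests and is cleared once a request finishes, so prefill is attempted whenever there is both waiting work and room.

from collections import deque


class Req:
    def __init__(self, output_len):
        self.remaining = output_len

    def decode_one_token(self):
        self.remaining -= 1

    def finished(self):
        return self.remaining <= 0


class ToyScheduler:
    def __init__(self, max_running_requests):
        self.max_running_requests = max_running_requests
        self.waiting_queue = deque()
        self.running = []
        self.batch_is_full = False

    def step(self):
        # Try prefill unless the batch is known to be full or nothing is waiting
        # (the old flag skipped prefill whenever the last decode step finished no
        # request, even when the running batch was tiny).
        if not self.batch_is_full and self.waiting_queue:
            while self.waiting_queue and len(self.running) < self.max_running_requests:
                self.running.append(self.waiting_queue.popleft())
            if len(self.running) >= self.max_running_requests:
                self.batch_is_full = True

        # One decode step over the whole running batch.
        for req in self.running:
            req.decode_one_token()
        if any(req.finished() for req in self.running):
            self.running = [r for r in self.running if not r.finished()]
            self.batch_is_full = False  # a slot freed up; try prefill again next step


sched = ToyScheduler(max_running_requests=10)
for _ in range(200):
    sched.waiting_queue.append(Req(output_len=32))
while sched.running or sched.waiting_queue:
    sched.step()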
@@ -514,7 +514,16 @@ def get_similarities(vec1, vec2):
     return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
 
 
-def run_bench_serving(model, num_prompts, request_rate, other_server_args):
+def run_bench_serving(
+    model,
+    num_prompts,
+    request_rate,
+    other_server_args,
+    dataset_name="random",
+    random_input_len=4096,
+    random_output_len=2048,
+    disable_stream=False,
+):
     # Launch the server
     base_url = DEFAULT_URL_FOR_TEST
     process = popen_launch_server(
@@ -530,21 +539,21 @@ def run_bench_serving(model, num_prompts, request_rate, other_server_args):
         base_url=base_url,
         host=None,
         port=None,
-        dataset_name="random",
+        dataset_name=dataset_name,
         dataset_path="",
         model=None,
         tokenizer=None,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
-        random_input_len=4096,
-        random_output_len=2048,
+        random_input_len=random_input_len,
+        random_output_len=random_output_len,
         random_range_ratio=0.0,
         request_rate=request_rate,
         multi=None,
         seed=0,
         output_file=None,
         disable_tqdm=False,
-        disable_stream=False,
+        disable_stream=disable_stream,
         disable_ignore_eos=False,
         extra_request_body=None,
     )
......
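Because the new parameters default to the previous hard-coded values, existing callers of run_bench_serving are unchanged; only callers that opt in get the new dataset and streaming behavior. Hypothetical call sites mirroring the new test below (the import path and model name are assumptions for illustration):

# Assumed import path for the helper; adjust to wherever run_bench_serving lives.
from sglang.test.test_utils import run_bench_serving

# Old-style call: random dataset, 4096/2048 token lengths, streaming enabled.
res_default = run_bench_serving(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # assumed model, for illustration
    num_prompts=200,
    request_rate=float("inf"),
    other_server_args=[],
)

# New-style call: ShareGPT prompts, non-streaming, small running batch.
res_small_batch = run_bench_serving(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    num_prompts=200,
    request_rate=float("inf"),
    other_server_args=["--max-running-requests", "10"],
    dataset_name="sharegpt",
    random_input_len=None,
    random_output_len=None,
    disable_stream=True,
)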
@@ -20,7 +20,22 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            assert res["output_throughput"] > 2830
+
+    def test_offline_throughput_non_stream_small_batch_size(self):
+        res = run_bench_serving(
+            model=DEFAULT_MODEL_NAME_FOR_TEST,
+            num_prompts=200,
+            request_rate=float("inf"),
+            dataset_name="sharegpt",
+            random_input_len=None,
+            random_output_len=None,
+            disable_stream=True,
+            other_server_args=["--max-running-requests", "10"],
+        )
+
+        if is_in_ci():
+            assert res["output_throughput"] > 1000
 
     def test_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(
@@ -31,7 +46,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2800
+            assert res["output_throughput"] > 2880
 
     def test_offline_throughput_without_chunked_prefill(self):
         res = run_bench_serving(
@@ -58,7 +73,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            assert res["output_throughput"] > 2930
 
     def test_offline_throughput_default_fp8(self):
         res = run_bench_serving(
......