Unverified Commit a2e0424a authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix memory leak for chunked prefill 2 (#1858)


Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
parent 8ce202a4
...@@ -50,7 +50,7 @@ jobs: ...@@ -50,7 +50,7 @@ jobs:
timeout-minutes: 20 timeout-minutes: 20
run: | run: |
cd test/srt cd test/srt
python3 run_suite.py --suite minimal --range-begin 0 --range-end 5 python3 run_suite.py --suite minimal --range-begin 0 --range-end 4
unit-test-backend-part-2: unit-test-backend-part-2:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...@@ -67,7 +67,7 @@ jobs: ...@@ -67,7 +67,7 @@ jobs:
timeout-minutes: 20 timeout-minutes: 20
run: | run: |
cd test/srt cd test/srt
python3 run_suite.py --suite minimal --range-begin 5 --range-end 17 python3 run_suite.py --suite minimal --range-begin 4 --range-end 14
unit-test-backend-part-3: unit-test-backend-part-3:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...@@ -84,7 +84,7 @@ jobs: ...@@ -84,7 +84,7 @@ jobs:
timeout-minutes: 20 timeout-minutes: 20
run: | run: |
cd test/srt cd test/srt
python3 run_suite.py --suite minimal --range-begin 17 --range-end 20 python3 run_suite.py --suite minimal --range-begin 14 --range-end 20
unit-test-backend-part-4: unit-test-backend-part-4:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
......
# Guide on Hyperparameter Tuning # Guide on Hyperparameter Tuning
## Achieving Peak Throughput ## Achieving Peak Throughput
Achieving a large batch size is the most important thing for attaining high throughput. Achieving a large batch size is the most important thing for attaining high throughput.
When the server is running at full load, look for the following in the log: When the server is running at full load, look for the following in the log:
......
...@@ -221,7 +221,7 @@ class Req: ...@@ -221,7 +221,7 @@ class Req:
self.prefix_indices = [] self.prefix_indices = []
self.extend_input_len = 0 self.extend_input_len = 0
self.last_node = None self.last_node = None
self.is_inflight_req = 0 self.is_being_chunked = 0
# Logprobs (arguments) # Logprobs (arguments)
self.return_logprob = False self.return_logprob = False
...@@ -888,7 +888,7 @@ class ScheduleBatch: ...@@ -888,7 +888,7 @@ class ScheduleBatch:
def filter_batch( def filter_batch(
self, self,
current_inflight_req: Optional[Req] = None, being_chunked_req: Optional[Req] = None,
keep_indices: Optional[List[int]] = None, keep_indices: Optional[List[int]] = None,
): ):
if keep_indices is None: if keep_indices is None:
...@@ -896,7 +896,7 @@ class ScheduleBatch: ...@@ -896,7 +896,7 @@ class ScheduleBatch:
i i
for i in range(len(self.reqs)) for i in range(len(self.reqs))
if not self.reqs[i].finished() if not self.reqs[i].finished()
and self.reqs[i] is not current_inflight_req and self.reqs[i] is not being_chunked_req
] ]
if keep_indices is None or len(keep_indices) == 0: if keep_indices is None or len(keep_indices) == 0:
......
...@@ -231,7 +231,7 @@ class Scheduler: ...@@ -231,7 +231,7 @@ class Scheduler:
# Init chunked prefill # Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size self.chunked_prefill_size = server_args.chunked_prefill_size
self.current_inflight_req = None self.being_chunked_req = None
self.is_mixed_chunk = ( self.is_mixed_chunk = (
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
) )
...@@ -551,13 +551,13 @@ class Scheduler: ...@@ -551,13 +551,13 @@ class Scheduler:
and not self.last_batch.forward_mode.is_decode() and not self.last_batch.forward_mode.is_decode()
and not self.last_batch.is_empty() and not self.last_batch.is_empty()
): ):
if self.current_inflight_req: if self.being_chunked_req:
self.last_batch.filter_batch( self.last_batch.filter_batch(
current_inflight_req=self.current_inflight_req being_chunked_req=self.being_chunked_req
) )
self.tree_cache.cache_unfinished_req(self.current_inflight_req) self.tree_cache.cache_unfinished_req(self.being_chunked_req)
# Inflight request keeps its rid but will get a new req_pool_idx. # Inflight request keeps its rid but will get a new req_pool_idx.
self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx) self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
self.batch_is_full = False self.batch_is_full = False
if not self.last_batch.is_empty(): if not self.last_batch.is_empty():
if self.running_batch is None: if self.running_batch is None:
...@@ -588,7 +588,7 @@ class Scheduler: ...@@ -588,7 +588,7 @@ class Scheduler:
# Handle the cases where prefill is not allowed # Handle the cases where prefill is not allowed
if ( if (
self.batch_is_full or len(self.waiting_queue) == 0 self.batch_is_full or len(self.waiting_queue) == 0
) and self.current_inflight_req is None: ) and self.being_chunked_req is None:
return None return None
running_bs = len(self.running_batch.reqs) if self.running_batch else 0 running_bs = len(self.running_batch.reqs) if self.running_batch else 0
...@@ -611,13 +611,11 @@ class Scheduler: ...@@ -611,13 +611,11 @@ class Scheduler:
num_mixed_running, num_mixed_running,
) )
has_inflight = self.current_inflight_req is not None has_inflight = self.being_chunked_req is not None
if has_inflight: if has_inflight:
self.current_inflight_req.init_next_round_input( self.being_chunked_req.init_next_round_input()
None if prefix_computed else self.tree_cache self.being_chunked_req = adder.add_inflight_req(
) self.being_chunked_req
self.current_inflight_req = adder.add_inflight_req(
self.current_inflight_req
) )
if self.lora_paths: if self.lora_paths:
...@@ -661,11 +659,11 @@ class Scheduler: ...@@ -661,11 +659,11 @@ class Scheduler:
] ]
if adder.new_inflight_req is not None: if adder.new_inflight_req is not None:
assert self.current_inflight_req is None assert self.being_chunked_req is None
self.current_inflight_req = adder.new_inflight_req self.being_chunked_req = adder.new_inflight_req
if self.current_inflight_req: if self.being_chunked_req:
self.current_inflight_req.is_inflight_req += 1 self.being_chunked_req.is_being_chunked += 1
# Print stats # Print stats
if self.tp_rank == 0: if self.tp_rank == 0:
...@@ -833,8 +831,8 @@ class Scheduler: ...@@ -833,8 +831,8 @@ class Scheduler:
# Check finish conditions # Check finish conditions
logprob_pt = 0 logprob_pt = 0
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
if req.is_inflight_req > 0: if req.is_being_chunked > 0:
req.is_inflight_req -= 1 req.is_being_chunked -= 1
else: else:
# Inflight reqs' prefill is not finished # Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1 req.completion_tokens_wo_jump_forward += 1
...@@ -860,8 +858,8 @@ class Scheduler: ...@@ -860,8 +858,8 @@ class Scheduler:
# Check finish conditions # Check finish conditions
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
req.embedding = embeddings[i] req.embedding = embeddings[i]
if req.is_inflight_req > 0: if req.is_being_chunked > 0:
req.is_inflight_req -= 1 req.is_being_chunked -= 1
else: else:
# Inflight reqs' prefill is not finished # Inflight reqs' prefill is not finished
# dummy output token for embedding models # dummy output token for embedding models
......
""" # Kill all SGLang processes and free the GPU memory.
Kill all SGLang processes and free the GPU memory.
"""
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}') kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
...@@ -19,6 +19,7 @@ suites = { ...@@ -19,6 +19,7 @@ suites = {
"test_openai_server.py", "test_openai_server.py",
"test_overlap_schedule.py", "test_overlap_schedule.py",
"test_pytorch_sampling_backend.py", "test_pytorch_sampling_backend.py",
"test_radix_attention.py",
"test_retract_decode.py", "test_retract_decode.py",
"test_server_args.py", "test_server_args.py",
"test_skip_tokenizer_init.py", "test_skip_tokenizer_init.py",
......
import os
import random
import unittest
import requests
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
kill_child_process,
popen_launch_server,
)
def gen_radix_tree(num_nodes=400, chunk_len=256):
    """Generate a randomized prompt tree for exercising the radix cache.

    Returns a shuffled list of dicts, each with:
      - "input_ids": token ids sharing a prefix with the node's parent
      - "decode_len": number of output tokens to request for that node

    The result has ``num_nodes + 1`` entries (the fixed root plus
    ``num_nodes`` children). Uses the global ``random`` state; seed it
    externally for reproducible trees.
    """
    half = num_nodes // 2
    remaining = num_nodes - half
    # Fixed root: every generated prompt extends this common prefix.
    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]

    def _spawn_child(parent):
        # NOTE: the random-call order (unique_len, decode_len, token_id)
        # matches the original straight-line code so seeded runs reproduce.
        unique_len = random.randint(0, chunk_len)
        decode_len = random.randint(0, chunk_len)
        token_id = random.randint(0, 32000)
        return {
            "input_ids": parent["input_ids"] + [token_id] * unique_len,
            "decode_len": decode_len,
        }

    # Phase 1: attach children one at a time, each under a random node.
    for _ in range(half):
        nodes.append(_spawn_child(random.choice(nodes)))

    # Phase 2: attach the rest in random-sized sibling groups (fan-out <= 10).
    while remaining > 0:
        num_branch = random.randint(1, min(remaining, 10))
        parent = random.choice(nodes)
        nodes.extend(_spawn_child(parent) for _ in range(num_branch))
        remaining -= num_branch

    # Shuffle so arrival order is decoupled from tree structure.
    random.shuffle(nodes)
    return nodes
def run_test(base_url, nodes):
    """POST all node prompts to ``/generate`` as one batched request.

    Greedy decoding (temperature 0) with each node's own output budget;
    fails the calling test if the server does not answer 200.
    """
    payload = {
        "input_ids": [item["input_ids"] for item in nodes],
        "sampling_params": [
            {"max_new_tokens": item["decode_len"], "temperature": 0}
            for item in nodes
        ],
    }
    response = requests.post(base_url + "/generate", json=payload)
    assert response.status_code == 200
class TestRadixCacheFCFS(unittest.TestCase):
    """End-to-end radix-cache test with chunked prefill under FCFS scheduling."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # A small chunk size (128) forces many chunked-prefill rounds,
        # which is what stresses the radix-cache bookkeeping.
        launch_args = [
            "--chunked-prefill-size", "128",
            "--max-total-tokens", "20000",
            "--schedule-policy", "fcfs",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        # Tear down the whole server process tree started in setUpClass.
        kill_child_process(cls.process.pid, include_self=True)

    def test_radix_attention(self):
        # Generate a fresh random prompt tree and run it through the server.
        run_test(self.base_url, gen_radix_tree())
class TestRadixCacheLPM(TestRadixCacheFCFS):
    """Same workload as the FCFS case, but launched with the "lpm" policy.

    Inherits test_radix_attention and tearDownClass from the FCFS base class.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Only the --schedule-policy value differs from the base class.
        launch_args = [
            "--chunked-prefill-size", "128",
            "--max-total-tokens", "20000",
            "--schedule-policy", "lpm",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )
if __name__ == "__main__":
    # NOTE(review): presumably SGLANG_TEST_RETRACT enables the server's
    # retraction code path for these tests — confirm against the scheduler's
    # env-var handling. Must be set before the server is launched in setUpClass.
    os.environ["SGLANG_TEST_RETRACT"] = "true"
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment