Unverified Commit b2ccf36d, authored by Lianmin Zheng, committed by GitHub

Fix memory leak during abort (#2238)

parent d4fc1a70
@@ -50,7 +50,7 @@ jobs:
       timeout-minutes: 25
       run: |
         cd test/srt
-        python3 run_suite.py --suite minimal --range-begin 0 --range-end 5
+        python3 run_suite.py --suite minimal --range-begin 0 --range-end 6

   unit-test-backend-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -67,7 +67,7 @@ jobs:
       timeout-minutes: 25
       run: |
         cd test/srt
-        python3 run_suite.py --suite minimal --range-begin 5 --range-end 14
+        python3 run_suite.py --suite minimal --range-begin 6 --range-end 15

   unit-test-backend-part-3:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -84,7 +84,7 @@ jobs:
       timeout-minutes: 25
       run: |
         cd test/srt
-        python3 run_suite.py --suite minimal --range-begin 14 --range-end 23
+        python3 run_suite.py --suite minimal --range-begin 15 --range-end 24

   unit-test-backend-part-4:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -101,7 +101,7 @@ jobs:
       timeout-minutes: 25
       run: |
         cd test/srt
-        python3 run_suite.py --suite minimal --range-begin 23
+        python3 run_suite.py --suite minimal --range-begin 24

   unit-test-backend-2-gpu-part-1:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
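Note: the shard boundaries in the four CI jobs above each shift by one. This matches the new test_abort.py entry added to the minimal suite in run_suite.py further down, which grows the suite by one test.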
@@ -231,6 +231,7 @@ class Req:
         self.tokenizer = None
         self.finished_reason = None
         self.stream = False
+        self.to_abort = False

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -368,6 +369,10 @@ class Req:
         if self.finished():
             return

+        if self.to_abort:
+            self.finished_reason = FINISH_ABORT()
+            return
+
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
             self.finished_reason = FINISH_LENGTH(
                 length=self.sampling_params.max_new_tokens
...
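With the new field in place, an abort becomes just another finish reason that check_finished() discovers, so an aborted request flows through the same completion path as one that stopped naturally. A minimal sketch of the pattern, with toy stand-in names (ToyReq and string reasons rather than the real Req/FINISH_* classes):

class ToyReq:
    def __init__(self, rid, max_new_tokens):
        self.rid = rid
        self.to_abort = False          # set asynchronously by the scheduler
        self.finished_reason = None
        self.output_ids = []
        self.max_new_tokens = max_new_tokens

    def finished(self):
        return self.finished_reason is not None

    def check_finished(self):
        if self.finished():
            return
        # The abort flag is checked before the length check, so a flagged
        # request stops on the next scheduler step regardless of token count.
        if self.to_abort:
            self.finished_reason = "abort"    # stands in for FINISH_ABORT()
            return
        if len(self.output_ids) >= self.max_new_tokens:
            self.finished_reason = "length"   # stands in for FINISH_LENGTH(...)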
@@ -579,6 +579,8 @@ class Scheduler:
                     "Image request length is longer than the KV cache pool size or "
                     "the max context length aborting because you cannot truncate the image embeds"
                 )
+                req.image_inputs = None
+                req.origin_input_ids = [0]
                 req.sampling_params.max_new_tokens = 0
                 self.waiting_queue.append(req)
                 return
@@ -1350,13 +1352,15 @@ class Scheduler:
         if to_del is not None:
             del self.waiting_queue[to_del]
+            logger.debug(f"Abort queued request. {req.rid=}")
+            return

         # Delete requests in the running batch
         if self.running_batch:
             for req in self.running_batch.reqs:
                 if req.rid == recv_req.rid and not req.finished():
-                    req.finished_reason = FINISH_ABORT()
-                    self.tree_cache.cache_finished_req(req)
+                    logger.debug(f"Abort running request. {req.rid=}")
+                    req.to_abort = True
                     break

     def update_weights(self, recv_req: UpdateWeightReqInput):
...
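A plausible reading of why the old abort path leaked: the handler finalized a running request out-of-band, setting finished_reason and releasing its cache entry directly while the batch loop still held the request, so the loop's own cleanup could be skipped or doubled. Setting to_abort instead leaves the batch loop as the single owner of finalization. A toy sketch of that single-owner pattern, reusing ToyReq from above (simplified names, not sglang internals):

class ToyScheduler:
    def __init__(self):
        self.running_reqs = []
        self.released_rids = []    # stands in for the radix/KV cache bookkeeping

    def abort_request(self, rid):
        # Flag only; never release cache entries from the abort handler.
        for req in self.running_reqs:
            if req.rid == rid and not req.finished():
                req.to_abort = True
                break

    def loop_step(self):
        for req in self.running_reqs:
            req.check_finished()
        for req in self.running_reqs:
            if req.finished():
                self.released_rids.append(req.rid)   # release happens here, exactly once
        self.running_reqs = [r for r in self.running_reqs if not r.finished()]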
@@ -677,8 +677,14 @@ def run_and_check_memory_leak(
     enable_mixed_chunk,
     disable_overlap,
     chunked_prefill_size,
+    assert_has_abort,
 ):
-    other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
+    other_args = [
+        "--chunked-prefill-size",
+        str(chunked_prefill_size),
+        "--log-level",
+        "debug",
+    ]
     if disable_radix_cache:
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
@@ -723,14 +729,19 @@ run_and_check_memory_leak(
     # Assert success
     has_new_server = False
     has_leak = False
+    has_abort = False
     for line in output_lines:
         if "The server is fired" in line:
             has_new_server = True
         if "leak" in line:
             has_leak = True
+        if "Abort" in line:
+            has_abort = True

     assert has_new_server
     assert not has_leak
+    if assert_has_abort:
+        assert has_abort


 def run_mmlu_test(
@@ -761,6 +772,7 @@ def run_mmlu_test(
         enable_mixed_chunk,
         disable_overlap,
         chunked_prefill_size,
+        assert_has_abort=False,
     )
@@ -800,4 +812,5 @@ def run_mulit_request_test(
         enable_mixed_chunk,
         enable_overlap,
         chunked_prefill_size,
+        assert_has_abort=False,
     )
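Note how the pieces fit together: the scheduler's new abort messages are logged at debug level, so run_and_check_memory_leak now starts the server with --log-level debug. Without that flag, the "Abort" substring scan over the captured output would never match, and a test passing assert_has_abort=True could not pass.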
@@ -10,6 +10,7 @@ suites = {
         "models/test_lora.py",
         "models/test_reward_models.py",
         "sampling/penaltylib",
+        "test_abort.py",
         "test_chunked_prefill.py",
         "test_double_sparsity.py",
         "test_embedding_openai_server.py",
...
New file: test/srt/test_abort.py

import multiprocessing
import time
import unittest
from concurrent.futures import ThreadPoolExecutor

import requests

from sglang.test.test_utils import run_and_check_memory_leak


class TestAbort(unittest.TestCase):
    def workload_func(self, base_url, model):
        # Launch a burst of long generations from a separate client process.
        def process_func():
            def run_one(_):
                prompt = """
                System: You are a helpful assistant.
                User: What is the capital of France?
                Assistant: The capital of France is
                """

                response = requests.post(
                    f"{base_url}/generate",
                    json={
                        "text": prompt,
                        "sampling_params": {
                            "temperature": 0,
                            "max_new_tokens": 2048,
                        },
                    },
                )
                ret = response.json()

            with ThreadPoolExecutor(16) as executor:
                list(executor.map(run_one, list(range(16))))

        p = multiprocessing.Process(target=process_func)
        p.start()
        time.sleep(0.5)
        # Kill the client mid-generation; the server must abort the in-flight requests.
        p.terminate()
        # Give the server time to process the aborts before the leak check runs.
        time.sleep(10)

    def test_memory_leak(self):
        run_and_check_memory_leak(
            self.workload_func,
            disable_radix_cache=False,
            enable_mixed_chunk=False,
            disable_overlap=False,
            chunked_prefill_size=8192,
            assert_has_abort=True,
        )


if __name__ == "__main__":
    unittest.main()
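To run just this test locally (assuming the usual layout, with the suite driven from test/srt as in the workflow above):

cd test/srt
python3 test_abort.py

The workload fires 16 concurrent 2048-token generations from a child process, kills that process half a second in, and then relies on run_and_check_memory_leak to verify that the server both logged the aborts and released all memory.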