"examples/controlnet/train_controlnet.py" did not exist on "eadf0e2555cfa19b033e02de53553f71ac33536f"
Unverified commit d3c275b1, authored by Albert and committed by GitHub

Support updating weights at once by stopping all requests (#6698)


Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com>
Co-authored-by: Zilin Zhu <zhuzilinallen@gmail.com>
parent b044400d
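
For orientation, here is a minimal client-side sketch of what this commit enables, assuming a locally running server on port 30000 (the port and checkpoint path are placeholders, not part of the change):

import requests

BASE_URL = "http://localhost:30000"  # assumption: local sglang server

# Abort every queued and running request in one call.
requests.post(BASE_URL + "/abort_request", json={"abort_all": True})

# Or have the weight-update endpoint abort everything itself before reloading.
resp = requests.post(
    BASE_URL + "/update_weights_from_disk",
    json={
        "model_path": "/path/to/new/checkpoint",  # placeholder path
        "abort_all_requests": True,
    },
)
print(resp.json())

Aborted requests finish with a finish_reason of type "abort", so callers can distinguish them from normal completions.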
@@ -662,7 +662,9 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
 async def abort_request(obj: AbortReq, request: Request):
     """Abort a request."""
     try:
-        _global_state.tokenizer_manager.abort_request(rid=obj.rid)
+        _global_state.tokenizer_manager.abort_request(
+            rid=obj.rid, abort_all=obj.abort_all
+        )
         return Response(status_code=200)
     except Exception as e:
         return _create_error_response(e)
@@ -236,7 +236,7 @@ class CompletionResponseStreamChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
+    finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
     matched_stop: Union[None, int, str] = None
     hidden_states: Optional[object] = None

@@ -510,7 +510,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[
-        Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
+        Literal[
+            "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
+        ]
     ] = None
     matched_stop: Union[None, int, str] = None
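
Streaming clients can now observe the new "abort" finish reason. A hedged sketch using the raw SSE stream of the OpenAI-compatible endpoint (the model name and port are assumptions):

import json

import requests

with requests.post(
    "http://localhost:30000/v1/chat/completions",  # assumption: local server
    json={
        "model": "default",  # assumption: served model name
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        if chunk["choices"][0].get("finish_reason") == "abort":
            print("request was aborted by the server")
            break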
@@ -740,6 +740,8 @@ class UpdateWeightFromDiskReqInput:
     model_path: str
     # The format to load the weights
     load_format: Optional[str] = None
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False


 @dataclass
@@ -759,6 +761,8 @@ class UpdateWeightsFromDistributedReqInput:
     group_name: str = "weight_update_group"
     # Whether to flush the cache after updating weights
     flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False


 @dataclass
@@ -780,6 +784,8 @@ class UpdateWeightsFromTensorReqInput:
     load_format: Optional[str] = None
     # Whether to flush the cache after updating weights
     flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False


 @dataclass
@@ -858,7 +864,9 @@ class SlowDownReqOutput:
 @dataclass
 class AbortReq:
     # The request id
-    rid: str
+    rid: str = ""
+    # Whether to abort all requests
+    abort_all: bool = False


 @dataclass
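
The new fields are appended with defaults, so existing positional call sites keep working. A small standalone sketch of that backward-compatibility property (the class is mirrored here for illustration, not imported from sglang):

from dataclasses import dataclass


@dataclass
class AbortReq:
    # The request id; empty means no specific request is targeted.
    rid: str = ""
    # Whether to abort all requests.
    abort_all: bool = False


old_style = AbortReq("request-123")          # existing positional call sites still work
abort_everything = AbortReq(abort_all=True)  # new behaviour is opt-in
assert old_style.abort_all is False and abort_everything.rid == ""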
@@ -2211,7 +2211,7 @@ class Scheduler(
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
-            if req.rid.startswith(recv_req.rid):
+            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                 to_del.append(i)

         # Sort in reverse order to avoid index issues when deleting
@@ -2228,7 +2228,7 @@ class Scheduler(
             # Abort method 2: call `set_finish_with_abort`
             # The request will still run one prefill forward pass.
             # In this case, we change the input_ids to be only one token to make this prefill cheap.
-            if req.rid.startswith(recv_req.rid):
+            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                 logger.debug(f"Abort grammar queue request. {req.rid=}")
                 if req.grammar:
                     req.grammar.cancel()
@@ -2241,7 +2241,9 @@ class Scheduler(
             reqs = self.running_batch.reqs + self.cur_batch.reqs

         for req in reqs:
-            if req.rid.startswith(recv_req.rid) and not req.finished():
+            if not req.finished() and (
+                recv_req.abort_all or req.rid.startswith(recv_req.rid)
+            ):
                 # Abort method 3: set `to_abort=True`
                 # The request will still run one decode forward pass.
                 # Then we reuse all existing code to clean up the KV cache allocation.
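
The "sort in reverse order" comment above refers to a standard pattern for removing several indices from a Python list; a standalone sketch:

# Deleting by ascending index shifts later elements and invalidates the
# remaining indices, so delete from the back of the list first.
waiting_queue = ["req-a", "req-b", "req-c", "req-d"]
to_del = [0, 2]  # indices selected for removal

for i in sorted(to_del, reverse=True):
    del waiting_queue[i]

assert waiting_queue == ["req-b", "req-d"]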
@@ -846,10 +846,10 @@ class TokenizerManager:
     async def flush_cache(self) -> FlushCacheReqOutput:
         return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

-    def abort_request(self, rid: str):
-        if rid not in self.rid_to_state:
+    def abort_request(self, rid: str = "", abort_all: bool = False):
+        if not abort_all and rid not in self.rid_to_state:
             return
-        req = AbortReq(rid)
+        req = AbortReq(rid, abort_all)
         self.send_to_scheduler.send_pyobj(req)

         if self.enable_metrics:
@@ -914,6 +914,9 @@ class TokenizerManager:
             obj.load_format = self.server_args.load_format
         logger.info("Start update_weights. Load format=%s", obj.load_format)

+        if obj.abort_all_requests:
+            self.abort_request(abort_all=True)
+
         if True:  # Keep this redundant check to simplify some internal code sync
             # Hold the lock if it is not async. This means that weight sync
             # cannot run while requests are in progress.
@@ -969,6 +972,9 @@ class TokenizerManager:
             self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
         ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"

+        if obj.abort_all_requests:
+            self.abort_request(abort_all=True)
+
         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
@@ -985,6 +991,9 @@ class TokenizerManager:
             self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
         ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"

+        if obj.abort_all_requests:
+            self.abort_request(abort_all=True)
+
         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
@@ -1619,7 +1628,23 @@ class TokenizerManager:
                 self.crash_dump_request_list.popleft()

     def _handle_abort_req(self, recv_obj):
-        self.rid_to_state.pop(recv_obj.rid, None)
+        state = self.rid_to_state[recv_obj.rid]
+        state.finished = True
+        state.out_list.append(
+            {
+                "text": "",
+                "meta_info": {
+                    "id": recv_obj.rid,
+                    "finish_reason": {
+                        "type": "abort",
+                        "message": "Abort before prefill",
+                    },
+                    "prompt_tokens": 0,
+                    "completion_tokens": 0,
+                },
+            }
+        )
+        state.event.set()

     def _handle_open_session_req_output(self, recv_obj):
         self.session_futures[recv_obj.session_id].set_result(
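
`_handle_abort_req` now finishes the waiting request with an explicit abort payload instead of silently dropping its state. A minimal standalone asyncio sketch of that notify pattern (the ReqState fields are mirrored loosely and are not the real sglang class):

import asyncio
from dataclasses import dataclass, field


@dataclass
class ReqState:
    out_list: list = field(default_factory=list)
    finished: bool = False
    event: asyncio.Event = field(default_factory=asyncio.Event)


async def waiter(state: ReqState):
    # The generate handler awaits new output for its request id.
    await state.event.wait()
    return state.out_list[-1]


async def main():
    state = ReqState()
    task = asyncio.create_task(waiter(state))
    # What the abort handler does, in spirit: push a terminal output and wake the waiter.
    state.finished = True
    state.out_list.append(
        {"text": "", "meta_info": {"finish_reason": {"type": "abort"}}}
    )
    state.event.set()
    print(await task)


asyncio.run(main())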
+import json
 import multiprocessing
 import time
 import unittest
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed

 import requests

-from sglang.test.test_utils import CustomTestCase, run_and_check_memory_leak
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+    run_and_check_memory_leak,
+)


 class TestAbort(CustomTestCase):
@@ -50,5 +59,56 @@ class TestAbort(CustomTestCase):
         )


+class TestAbortAll(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--max-running-requests", 8],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def _run_decode(self):
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 16000,
+                    "ignore_eos": True,
+                },
+            },
+        )
+        return response.json()
+
+    def test_abort_all(self):
+        num_requests = 32
+        with ThreadPoolExecutor(num_requests) as executor:
+            futures = [executor.submit(self._run_decode) for _ in range(num_requests)]
+
+            # ensure the decode has been started
+            time.sleep(2)
+
+            requests.post(
+                self.base_url + "/abort_request",
+                json={
+                    "abort_all": True,
+                },
+            )
+
+            for future in as_completed(futures):
+                self.assertEqual(
+                    future.result()["meta_info"]["finish_reason"]["type"], "abort"
+                )
+
+
 if __name__ == "__main__":
     unittest.main()
 import json
 import random
+import time
 import unittest
+from concurrent.futures import ThreadPoolExecutor, as_completed

 import requests

@@ -153,6 +155,82 @@ class TestServerUpdateWeightsFromDisk(CustomTestCase):
         self.assertEqual(origin_response[:32], updated_response[:32])


+class TestServerUpdateWeightsFromDiskAbortAllRequests(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--max-running-requests", 8],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def run_decode(self, max_new_tokens=32):
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": max_new_tokens,
+                    "ignore_eos": True,
+                },
+            },
+        )
+        return response.json()
+
+    def get_model_info(self):
+        response = requests.get(self.base_url + "/get_model_info")
+        model_path = response.json()["model_path"]
+        print(json.dumps(response.json()))
+        return model_path
+
+    def run_update_weights(self, model_path, abort_all_requests=False):
+        response = requests.post(
+            self.base_url + "/update_weights_from_disk",
+            json={
+                "model_path": model_path,
+                "abort_all_requests": abort_all_requests,
+            },
+        )
+        ret = response.json()
+        print(json.dumps(ret))
+        return ret
+
+    def test_update_weights_abort_all_requests(self):
+        origin_model_path = self.get_model_info()
+        print(f"[Server Mode] origin_model_path: {origin_model_path}")
+
+        num_requests = 32
+        with ThreadPoolExecutor(num_requests) as executor:
+            futures = [
+                executor.submit(self.run_decode, 16000) for _ in range(num_requests)
+            ]
+
+            # ensure the decode has been started
+            time.sleep(2)
+
+            new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
+            ret = self.run_update_weights(new_model_path, abort_all_requests=True)
+            self.assertTrue(ret["success"])
+
+            for future in as_completed(futures):
+                self.assertEqual(
+                    future.result()["meta_info"]["finish_reason"]["type"], "abort"
+                )
+
+        updated_model_path = self.get_model_info()
+        print(f"[Server Mode] updated_model_path: {updated_model_path}")
+        self.assertEqual(updated_model_path, new_model_path)
+        self.assertNotEqual(updated_model_path, origin_model_path)
+
+
 ###############################################################################
 # Parameterized Tests for update_weights_from_disk
 # Test coverage is determined based on the value of is_in_ci: