"vscode:/vscode.git/clone" did not exist on "b1fbef544c993288447181a6c8d8c68d89387ebe"
Unverified commit 81d27c8e authored by fzyzcjy, committed by GitHub

Refactor to add TypeBasedDispatcher to simplify dispatching (#2958)

parent 4d4cdb3f
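Both the Scheduler and the TokenizerManager previously routed incoming IPC messages through long isinstance/elif chains. This commit replaces those chains with a small table-driven helper, TypeBasedDispatcher (added to sglang/utils.py at the bottom of this diff): each (request type, handler) pair is registered once at construction time, and dispatching a message means calling the handler registered for the first matching type.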
python/sglang/srt/managers/scheduler.py

@@ -97,7 +97,7 @@ from sglang.srt.utils import (
     set_random_seed,
     suppress_other_loggers,
 )
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -422,6 +422,34 @@ class Scheduler:
             },
         )

+        self._dispatcher = TypeBasedDispatcher(
+            [
+                (TokenizedGenerateReqInput, self.handle_generate_request),
+                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+                (FlushCacheReq, self.flush_cache_wrapped),
+                (AbortReq, self.abort_request),
+                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
+                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                (
+                    UpdateWeightsFromDistributedReqInput,
+                    self.update_weights_from_distributed,
+                ),
+                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
+                (GetWeightsByNameReqInput, self.get_weights_by_name),
+                (ProfileReq, self.profile),
+                (OpenSessionReqInput, self.open_session),
+                (CloseSessionReqInput, self.close_session),
+                (
+                    ReleaseMemoryOccupationReqInput,
+                    lambda _: self.release_memory_occupation(),
+                ),
+                (
+                    ResumeMemoryOccupationReqInput,
+                    lambda _: self.resume_memory_occupation(),
+                ),
+            ]
+        )
+
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
@@ -563,57 +591,9 @@ class Scheduler:

     def process_input_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
-            if isinstance(recv_req, TokenizedGenerateReqInput):
-                self.handle_generate_request(recv_req)
-            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
-                self.handle_embedding_request(recv_req)
-            elif isinstance(recv_req, FlushCacheReq):
-                self.flush_cache()
-            elif isinstance(recv_req, AbortReq):
-                self.abort_request(recv_req)
-            elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
-                success, message = self.update_weights_from_disk(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    UpdateWeightFromDiskReqOutput(success, message)
-                )
-            elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
-                success, message = self.init_weights_update_group(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    InitWeightsUpdateGroupReqOutput(success, message)
-                )
-            elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
-                success, message = self.update_weights_from_distributed(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    UpdateWeightsFromDistributedReqOutput(success, message)
-                )
-            elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
-                success, message = self.update_weights_from_tensor(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    UpdateWeightsFromTensorReqOutput(success, message)
-                )
-            elif isinstance(recv_req, GetWeightsByNameReqInput):
-                parameter = self.get_weights_by_name(recv_req)
-                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
-            elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
-                self.release_memory_occupation()
-                self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
-            elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
-                self.resume_memory_occupation()
-                self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
-            elif isinstance(recv_req, ProfileReq):
-                if recv_req == ProfileReq.START_PROFILE:
-                    self.start_profile()
-                else:
-                    self.stop_profile()
-            elif isinstance(recv_req, OpenSessionReqInput):
-                session_id, success = self.open_session(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    OpenSessionReqOutput(session_id=session_id, success=success)
-                )
-            elif isinstance(recv_req, CloseSessionReqInput):
-                self.close_session(recv_req)
-            else:
-                raise ValueError(f"Invalid request: {recv_req}")
+            output = self._dispatcher(recv_req)
+            if output is not None:
+                self.send_to_tokenizer.send_pyobj(output)

     def handle_generate_request(
         self,
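Note the contract this hunk establishes: each handler either returns None or a typed *ReqOutput object, and the loop forwards any non-None result to the tokenizer. The per-branch send_pyobj boilerplate from the removed code therefore collapses into the handlers' return values, as the return-type changes in the hunks below show.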
@@ -1545,6 +1525,9 @@ class Scheduler:
         self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
+        self.flush_cache()
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
@@ -1597,12 +1580,12 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return success, message
+        return UpdateWeightFromDiskReqOutput(success, message)

     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
         success, message = self.tp_worker.init_weights_update_group(recv_req)
-        return success, message
+        return InitWeightsUpdateGroupReqOutput(success, message)

     def update_weights_from_distributed(
         self,
@@ -1615,7 +1598,7 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return success, message
+        return UpdateWeightsFromDistributedReqOutput(success, message)

     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         """Update the online model parameter from tensors."""
@@ -1626,11 +1609,11 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return success, message
+        return UpdateWeightsFromTensorReqOutput(success, message)

     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
-        return parameter
+        return GetWeightsByNameReqOutput(parameter)
     def release_memory_occupation(self):
         self.stashed_model_static_state = _export_static_state(
@@ -1638,6 +1621,7 @@ class Scheduler:
         )
         self.memory_saver_adapter.pause()
         self.flush_cache()
+        return ReleaseMemoryOccupationReqOutput()

     def resume_memory_occupation(self):
         self.memory_saver_adapter.resume()
@@ -1645,6 +1629,13 @@ class Scheduler:
             self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
         )
         del self.stashed_model_static_state
+        return ResumeMemoryOccupationReqOutput()
+
+    def profile(self, recv_req: ProfileReq):
+        if recv_req == ProfileReq.START_PROFILE:
+            self.start_profile()
+        else:
+            self.stop_profile()
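One subtlety: ProfileReq is registered in the dispatcher as a type, yet the scheduler receives enum values such as ProfileReq.START_PROFILE. This works because Python enum members are instances of their enum class, so isinstance-based dispatch routes both values to the single profile() handler above. A self-contained check (the enum definition below is assumed for illustration; the diff only shows the member names):

from enum import Enum, auto


class ProfileReq(Enum):
    # Assumed definition; only START_PROFILE appears in this diff.
    START_PROFILE = auto()
    STOP_PROFILE = auto()


# Enum members are instances of their Enum class, so a (ProfileReq, handler)
# entry in TypeBasedDispatcher matches both values:
assert isinstance(ProfileReq.START_PROFILE, ProfileReq)
assert isinstance(ProfileReq.STOP_PROFILE, ProfileReq)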
     def start_profile(self) -> None:
         if self.profiler is None:
@@ -1660,20 +1651,20 @@ class Scheduler:
         )
         logger.info("Profiler is done")

-    def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
+    def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
         session_id = recv_req.session_id
         if session_id in self.sessions:
             logger.warning(f"session id {session_id} already exist, cannot open.")
-            return session_id, False
+            return OpenSessionReqOutput(session_id, False)
         elif session_id is None:
             logger.warning(f"session id is None, cannot open.")
-            return session_id, False
+            return OpenSessionReqOutput(session_id, False)
         else:
             self.sessions[session_id] = Session(
                 recv_req.capacity_of_str_len, session_id
             )
-            return session_id, True
+            return OpenSessionReqOutput(session_id, True)

     def close_session(self, recv_req: CloseSessionReqInput):
         # handle error
...
python/sglang/srt/managers/tokenizer_manager.py

@@ -80,7 +80,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -221,6 +221,43 @@ class TokenizerManager:
             },
         )

+        self._dispatcher = TypeBasedDispatcher(
+            [
+                (BatchStrOut, self._handle_batch_output),
+                (BatchEmbeddingOut, self._handle_batch_output),
+                (BatchTokenIDOut, self._handle_batch_output),
+                (OpenSessionReqOutput, self._handle_open_session_req_output),
+                (
+                    UpdateWeightFromDiskReqOutput,
+                    self._handle_update_weights_from_disk_req_output,
+                ),
+                (
+                    InitWeightsUpdateGroupReqOutput,
+                    self.init_weights_update_group_communicator.handle_recv,
+                ),
+                (
+                    UpdateWeightsFromDistributedReqOutput,
+                    self.update_weights_from_distributed_communicator.handle_recv,
+                ),
+                (
+                    UpdateWeightsFromTensorReqOutput,
+                    self.update_weights_from_tensor_communicator.handle_recv,
+                ),
+                (
+                    GetWeightsByNameReqOutput,
+                    self.get_weights_by_name_communicator.handle_recv,
+                ),
+                (
+                    ReleaseMemoryOccupationReqOutput,
+                    self.release_memory_occupation_communicator.handle_recv,
+                ),
+                (
+                    ResumeMemoryOccupationReqOutput,
+                    self.resume_memory_occupation_communicator.handle_recv,
+                ),
+            ]
+        )
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -712,110 +749,64 @@ class TokenizerManager:
         """The event loop that handles requests"""
         while True:
-            recv_obj: Union[
-                BatchStrOut,
-                BatchEmbeddingOut,
-                BatchTokenIDOut,
-                UpdateWeightFromDiskReqOutput,
-                UpdateWeightsFromDistributedReqOutput,
-                GetWeightsByNameReqOutput,
-                InitWeightsUpdateGroupReqOutput,
-                ReleaseMemoryOccupationReqOutput,
-                ResumeMemoryOccupationReqOutput,
-            ] = await self.recv_from_detokenizer.recv_pyobj()
-
-            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
-                for i, rid in enumerate(recv_obj.rids):
-                    state = self.rid_to_state.get(rid, None)
-                    if state is None:
-                        continue
-
-                    meta_info = {
-                        "id": rid,
-                        "finish_reason": recv_obj.finished_reasons[i],
-                        "prompt_tokens": recv_obj.prompt_tokens[i],
-                    }
-
-                    if getattr(state.obj, "return_logprob", False):
-                        self.convert_logprob_style(
-                            meta_info,
-                            state.obj.top_logprobs_num,
-                            state.obj.return_text_in_logprobs,
-                            recv_obj,
-                            i,
-                        )
-
-                    if not isinstance(recv_obj, BatchEmbeddingOut):
-                        meta_info.update(
-                            {
-                                "completion_tokens": recv_obj.completion_tokens[i],
-                                "cached_tokens": recv_obj.cached_tokens[i],
-                            }
-                        )
-
-                    if isinstance(recv_obj, BatchStrOut):
-                        out_dict = {
-                            "text": recv_obj.output_strs[i],
-                            "meta_info": meta_info,
-                        }
-                    elif isinstance(recv_obj, BatchTokenIDOut):
-                        out_dict = {
-                            "token_ids": recv_obj.output_ids[i],
-                            "meta_info": meta_info,
-                        }
-                    else:
-                        assert isinstance(recv_obj, BatchEmbeddingOut)
-                        out_dict = {
-                            "embedding": recv_obj.embeddings[i],
-                            "meta_info": meta_info,
-                        }
-                    state.out_list.append(out_dict)
-                    state.finished = recv_obj.finished_reasons[i] is not None
-                    state.event.set()
-
-                    if self.enable_metrics and state.obj.log_metrics:
-                        self.collect_metrics(state, recv_obj, i)
-                    if (
-                        self.dump_requests_folder
-                        and state.finished
-                        and state.obj.log_metrics
-                    ):
-                        self.dump_requests(state, out_dict)
-            elif isinstance(recv_obj, OpenSessionReqOutput):
-                self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id if recv_obj.success else None
-                )
-            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
-                if self.server_args.dp_size == 1:
-                    self.model_update_result.set_result(recv_obj)
-                else:  # self.server_args.dp_size > 1
-                    self.model_update_tmp.append(recv_obj)
-                    # set future if the all results are recevied
-                    if len(self.model_update_tmp) == self.server_args.dp_size:
-                        self.model_update_result.set_result(self.model_update_tmp)
-            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for init parameter update group"
-                self.init_weights_update_group_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for update weights from distributed"
-                self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for update weights from distributed"
-                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
-                self.get_weights_by_name_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
-                self.release_memory_occupation_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
-                self.resume_memory_occupation_communicator.handle_recv(recv_obj)
-            else:
-                raise ValueError(f"Invalid object: {recv_obj=}")
+            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+            self._dispatcher(recv_obj)
+
+    def _handle_batch_output(
+        self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
+    ):
+        for i, rid in enumerate(recv_obj.rids):
+            state = self.rid_to_state.get(rid, None)
+            if state is None:
+                continue
+
+            meta_info = {
+                "id": rid,
+                "finish_reason": recv_obj.finished_reasons[i],
+                "prompt_tokens": recv_obj.prompt_tokens[i],
+            }
+
+            if getattr(state.obj, "return_logprob", False):
+                self.convert_logprob_style(
+                    meta_info,
+                    state.obj.top_logprobs_num,
+                    state.obj.return_text_in_logprobs,
+                    recv_obj,
+                    i,
+                )
+
+            if not isinstance(recv_obj, BatchEmbeddingOut):
+                meta_info.update(
+                    {
+                        "completion_tokens": recv_obj.completion_tokens[i],
+                        "cached_tokens": recv_obj.cached_tokens[i],
+                    }
+                )
+
+            if isinstance(recv_obj, BatchStrOut):
+                out_dict = {
+                    "text": recv_obj.output_strs[i],
+                    "meta_info": meta_info,
+                }
+            elif isinstance(recv_obj, BatchTokenIDOut):
+                out_dict = {
+                    "token_ids": recv_obj.output_ids[i],
+                    "meta_info": meta_info,
+                }
+            else:
+                assert isinstance(recv_obj, BatchEmbeddingOut)
+                out_dict = {
+                    "embedding": recv_obj.embeddings[i],
+                    "meta_info": meta_info,
+                }
+            state.out_list.append(out_dict)
+            state.finished = recv_obj.finished_reasons[i] is not None
+            state.event.set()
+
+            if self.enable_metrics and state.obj.log_metrics:
+                self.collect_metrics(state, recv_obj, i)
+            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
+                self.dump_requests(state, out_dict)
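On the TokenizerManager side the event loop shrinks to a single recv_pyobj/_dispatcher pair, with batch outputs handled by the extracted _handle_batch_output method. The communicators' bound handle_recv methods are registered directly as handlers; note that the dp_size == 1 assertions that guarded the old weight-update branches are no longer checked in this loop.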
     def convert_logprob_style(
         self,
@@ -943,6 +934,20 @@ class TokenizerManager:
         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))

+    def _handle_open_session_req_output(self, recv_obj):
+        self.session_futures[recv_obj.session_id].set_result(
+            recv_obj.session_id if recv_obj.success else None
+        )
+
+    def _handle_update_weights_from_disk_req_output(self, recv_obj):
+        if self.server_args.dp_size == 1:
+            self.model_update_result.set_result(recv_obj)
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp.append(recv_obj)
+            # set the future once all results are received
+            if len(self.model_update_tmp) == self.server_args.dp_size:
+                self.model_update_result.set_result(self.model_update_tmp)

 async def print_exception_wrapper(func):
     """
...
python/sglang/utils.py

@@ -15,7 +15,7 @@ import urllib.request
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps
-from typing import Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Type, Union

 import numpy as np
 import requests

@@ -363,3 +363,14 @@ def terminate_process(process):
 def print_highlight(html_content: str):
     html_content = str(html_content).replace("\n", "<br>")
     display(HTML(f"<strong style='color: #00008B;'>{html_content}</strong>"))
+
+
+class TypeBasedDispatcher:
+    def __init__(self, mapping: List[Tuple[Type, Callable]]):
+        self._mapping = mapping
+
+    def __call__(self, obj: Any):
+        for ty, fn in self._mapping:
+            if isinstance(obj, ty):
+                return fn(obj)
+        raise ValueError(f"Invalid object: {obj}")
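For readers unfamiliar with the pattern, here is a minimal usage sketch of the TypeBasedDispatcher class above. FooReq and BarReq are made-up types for illustration, not part of sglang:

from dataclasses import dataclass


@dataclass
class FooReq:
    x: int


@dataclass
class BarReq:
    y: str


dispatcher = TypeBasedDispatcher(
    [
        (FooReq, lambda req: req.x + 1),      # handler result is returned to the caller
        (BarReq, lambda req: req.y.upper()),
    ]
)

print(dispatcher(FooReq(41)))    # 42
print(dispatcher(BarReq("ok")))  # OK
dispatcher(3.14)                 # raises ValueError: no registered type matches

Entries are scanned in registration order with isinstance, so when one registered type subclasses another, the more specific type should be listed first.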