Unverified Commit f2388f6b authored by Lianmin Zheng, committed by GitHub

Revert "Rename TokenizerManager to StdOrchestrator" (#3828)

parent c9745ee0
...@@ -426,7 +426,7 @@ ...@@ -426,7 +426,7 @@
"from sglang.srt.managers.io_struct import Tool, Function\n", "from sglang.srt.managers.io_struct import Tool, Function\n",
"\n", "\n",
"llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n", "llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
"tokenizer = llm.orchestrator.tokenizer\n", "tokenizer = llm.tokenizer_manager.tokenizer\n",
"input_ids = tokenizer.apply_chat_template(\n", "input_ids = tokenizer.apply_chat_template(\n",
" messages, tokenize=True, add_generation_prompt=True, tools=tools\n", " messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
")\n", ")\n",
......
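The notebook cell above reaches into the engine for its Hugging Face tokenizer (`llm.tokenizer_manager.tokenizer` after this revert) and applies the chat template with tool definitions. A minimal standalone sketch of the same templating step is shown below, using `transformers.AutoTokenizer` directly; the model id, tool schema, and messages are illustrative assumptions, not taken from the notebook.

```python
# Hedged sketch: build a tool-calling prompt with a Hugging Face chat template.
# The model id, tool schema, and messages are placeholders; inside sglang the same
# tokenizer object is reachable as llm.tokenizer_manager.tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]
messages = [{"role": "user", "content": "What is the weather in Paris?"}]

# Returns token ids that can then be fed back to the engine as input_ids.
input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, tools=tools
)
print(len(input_ids), "prompt tokens")
```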
...@@ -48,8 +48,8 @@ from sglang.srt.managers.io_struct import ( ...@@ -48,8 +48,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqInput,
) )
from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
from sglang.srt.orchestration.std.orchestrator import StdOrchestrator
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import ( from sglang.srt.utils import (
...@@ -74,12 +74,12 @@ class Engine: ...@@ -74,12 +74,12 @@ class Engine:
The entry point to the inference engine. The entry point to the inference engine.
- The engine consists of three components: - The engine consists of three components:
1. StdOrchestrator: Tokenizes the requests and sends them to the scheduler. 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
Note: Note:
1. The HTTP server, Engine, and StdOrchestrator all run in the main process. 1. The HTTP server, Engine, and TokenizerManager all run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library. 2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
""" """
...@@ -102,8 +102,10 @@ class Engine: ...@@ -102,8 +102,10 @@ class Engine:
atexit.register(self.shutdown) atexit.register(self.shutdown)
# Launch subprocesses # Launch subprocesses
orchestrator, scheduler_info = _launch_subprocesses(server_args=server_args) tokenizer_manager, scheduler_info = _launch_subprocesses(
self.orchestrator = orchestrator server_args=server_args
)
self.tokenizer_manager = tokenizer_manager
self.scheduler_info = scheduler_info self.scheduler_info = scheduler_info
def generate( def generate(
...@@ -145,7 +147,7 @@ class Engine: ...@@ -145,7 +147,7 @@ class Engine:
stream=stream, stream=stream,
) )
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
generator = self.orchestrator.generate_request(obj, None) generator = self.tokenizer_manager.generate_request(obj, None)
if stream: if stream:
...@@ -195,7 +197,7 @@ class Engine: ...@@ -195,7 +197,7 @@ class Engine:
stream=stream, stream=stream,
custom_logit_processor=custom_logit_processor, custom_logit_processor=custom_logit_processor,
) )
generator = self.orchestrator.generate_request(obj, None) generator = self.tokenizer_manager.generate_request(obj, None)
if stream is True: if stream is True:
return generator return generator
...@@ -213,7 +215,7 @@ class Engine: ...@@ -213,7 +215,7 @@ class Engine:
obj = EmbeddingReqInput(text=prompt) obj = EmbeddingReqInput(text=prompt)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
generator = self.orchestrator.generate_request(obj, None) generator = self.tokenizer_manager.generate_request(obj, None)
ret = loop.run_until_complete(generator.__anext__()) ret = loop.run_until_complete(generator.__anext__())
return ret return ret
...@@ -222,14 +224,14 @@ class Engine: ...@@ -222,14 +224,14 @@ class Engine:
kill_process_tree(os.getpid(), include_parent=False) kill_process_tree(os.getpid(), include_parent=False)
def start_profile(self): def start_profile(self):
self.orchestrator.start_profile() self.tokenizer_manager.start_profile()
def stop_profile(self): def stop_profile(self):
self.orchestrator.stop_profile() self.tokenizer_manager.stop_profile()
def get_server_info(self): def get_server_info(self):
return { return {
**dataclasses.asdict(self.orchestrator.server_args), # server args **dataclasses.asdict(self.tokenizer_manager.server_args), # server args
**self.scheduler_info, **self.scheduler_info,
"version": __version__, "version": __version__,
} }
...@@ -254,7 +256,7 @@ class Engine: ...@@ -254,7 +256,7 @@ class Engine:
) )
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.run_until_complete( return loop.run_until_complete(
self.orchestrator.init_weights_update_group(obj, None) self.tokenizer_manager.init_weights_update_group(obj, None)
) )
def update_weights_from_distributed(self, name: str, dtype, shape): def update_weights_from_distributed(self, name: str, dtype, shape):
...@@ -266,7 +268,7 @@ class Engine: ...@@ -266,7 +268,7 @@ class Engine:
) )
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.run_until_complete( return loop.run_until_complete(
self.orchestrator.update_weights_from_distributed(obj, None) self.tokenizer_manager.update_weights_from_distributed(obj, None)
) )
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]): def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
...@@ -276,21 +278,23 @@ class Engine: ...@@ -276,21 +278,23 @@ class Engine:
) )
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.run_until_complete( return loop.run_until_complete(
self.orchestrator.update_weights_from_tensor(obj, None) self.tokenizer_manager.update_weights_from_tensor(obj, None)
) )
def get_weights_by_name(self, name: str, truncate_size: int = 100): def get_weights_by_name(self, name: str, truncate_size: int = 100):
"""Get weights by parameter name.""" """Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.run_until_complete(self.orchestrator.get_weights_by_name(obj, None)) return loop.run_until_complete(
self.tokenizer_manager.get_weights_by_name(obj, None)
)
def release_memory_occupation(self): def release_memory_occupation(self):
"""Release GPU occupation temporarily.""" """Release GPU occupation temporarily."""
obj = ReleaseMemoryOccupationReqInput() obj = ReleaseMemoryOccupationReqInput()
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.run_until_complete( return loop.run_until_complete(
self.orchestrator.release_memory_occupation(obj, None) self.tokenizer_manager.release_memory_occupation(obj, None)
) )
def resume_memory_occupation(self): def resume_memory_occupation(self):
...@@ -298,7 +302,7 @@ class Engine: ...@@ -298,7 +302,7 @@ class Engine:
obj = ResumeMemoryOccupationReqInput() obj = ResumeMemoryOccupationReqInput()
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.run_until_complete( return loop.run_until_complete(
self.orchestrator.resume_memory_occupation(obj, None) self.tokenizer_manager.resume_memory_occupation(obj, None)
) )
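The methods above (`release_memory_occupation`, `resume_memory_occupation`, `update_weights_from_tensor`) are typically combined when the engine shares a GPU with a trainer. A hedged sketch of that flow follows; the training step and the parameter name/tensor are placeholders, not values from this diff.

```python
# Hedged sketch: free the engine's GPU memory while another process trains,
# then resume and push updated weights. Trainer step and tensor names are placeholders.
import sglang as sgl
import torch

llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")

llm.release_memory_occupation()          # temporarily release GPU memory held by the engine

# ... run a training step elsewhere and collect updated parameters ...
updated = [("model.embed_tokens.weight", torch.zeros(8, 8))]  # placeholder name/tensor

llm.resume_memory_occupation()           # re-acquire GPU memory
llm.update_weights_from_tensor(updated)  # push the new weights in-place
llm.shutdown()
```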
...@@ -347,9 +351,9 @@ def _set_envs_and_config(server_args: ServerArgs): ...@@ -347,9 +351,9 @@ def _set_envs_and_config(server_args: ServerArgs):
mp.set_start_method("spawn", force=True) mp.set_start_method("spawn", force=True)
def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
""" """
Launch the StdOrchestrator in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
""" """
# Configure global environment # Configure global environment
configure_logger(server_args) configure_logger(server_args)
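The docstring above describes three processes wired together with ZMQ sockets, one port per process. The toy sketch below illustrates that PUSH/PULL pipeline shape only; it is not sglang's actual socket layout, and the ports, message fields, and "tokenize/generate/detokenize" payloads are made up.

```python
# Hedged sketch of a TokenizerManager -> Scheduler -> DetokenizerManager -> TokenizerManager
# loop over ZMQ PUSH/PULL sockets. Ports and payloads are placeholders.
import multiprocessing as mp
import zmq

SCHED_PORT, DETOK_PORT, TOKENIZER_PORT = 25555, 25556, 25557

def scheduler_proc():
    ctx = zmq.Context()
    recv = ctx.socket(zmq.PULL)
    recv.bind(f"tcp://127.0.0.1:{SCHED_PORT}")
    send = ctx.socket(zmq.PUSH)
    send.connect(f"tcp://127.0.0.1:{DETOK_PORT}")
    req = recv.recv_json()                                        # tokenized request from main
    send.send_json({"rid": req["rid"], "output_ids": [42, 43]})   # "generated" tokens

def detokenizer_proc():
    ctx = zmq.Context()
    recv = ctx.socket(zmq.PULL)
    recv.bind(f"tcp://127.0.0.1:{DETOK_PORT}")
    send = ctx.socket(zmq.PUSH)
    send.connect(f"tcp://127.0.0.1:{TOKENIZER_PORT}")
    out = recv.recv_json()
    send.send_json({"rid": out["rid"], "text": "<detokenized text>"})

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)       # same start method the engine forces
    ctx = zmq.Context()
    result = ctx.socket(zmq.PULL)
    result.bind(f"tcp://127.0.0.1:{TOKENIZER_PORT}")

    procs = [mp.Process(target=scheduler_proc), mp.Process(target=detokenizer_proc)]
    for p in procs:
        p.start()

    to_scheduler = ctx.socket(zmq.PUSH)
    to_scheduler.connect(f"tcp://127.0.0.1:{SCHED_PORT}")
    to_scheduler.send_json({"rid": "req-0", "input_ids": [1, 2, 3]})  # "tokenized" request

    print(result.recv_json())                      # {'rid': 'req-0', 'text': '<detokenized text>'}
    for p in procs:
        p.join()
```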
...@@ -432,10 +436,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict ...@@ -432,10 +436,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict
detoken_proc.start() detoken_proc.start()
# Launch tokenizer process # Launch tokenizer process
orchestrator = StdOrchestrator(server_args, port_args) tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template: if server_args.chat_template:
load_chat_template_for_openai_api( load_chat_template_for_openai_api(
orchestrator, server_args.chat_template, server_args.model_path tokenizer_manager, server_args.chat_template, server_args.model_path
) )
# Wait for the model to finish loading # Wait for the model to finish loading
...@@ -459,5 +463,5 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict ...@@ -459,5 +463,5 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict
# Assume all schedulers have the same scheduler_info # Assume all schedulers have the same scheduler_info
scheduler_info = scheduler_infos[0] scheduler_info = scheduler_infos[0]
orchestrator.configure_max_req_input_len(scheduler_info["max_req_input_len"]) tokenizer_manager.configure_max_req_input_len(scheduler_info["max_req_input_len"])
return orchestrator, scheduler_info return tokenizer_manager, scheduler_info
...@@ -54,6 +54,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -54,6 +54,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqInput,
) )
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.adapter import ( from sglang.srt.openai_api.adapter import (
v1_batches, v1_batches,
...@@ -68,7 +69,6 @@ from sglang.srt.openai_api.adapter import ( ...@@ -68,7 +69,6 @@ from sglang.srt.openai_api.adapter import (
v1_retrieve_file_content, v1_retrieve_file_content,
) )
from sglang.srt.openai_api.protocol import ModelCard, ModelList from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.orchestration.std.orchestrator import StdOrchestrator
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
add_api_key_middleware, add_api_key_middleware,
...@@ -97,7 +97,7 @@ app.add_middleware( ...@@ -97,7 +97,7 @@ app.add_middleware(
# Store global states # Store global states
@dataclasses.dataclass @dataclasses.dataclass
class _GlobalState: class _GlobalState:
orchestrator: StdOrchestrator tokenizer_manager: TokenizerManager
scheduler_info: Dict scheduler_info: Dict
...@@ -124,7 +124,7 @@ async def health_generate(request: Request) -> Response: ...@@ -124,7 +124,7 @@ async def health_generate(request: Request) -> Response:
sampling_params = {"max_new_tokens": 1, "temperature": 0.7} sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
if _global_state.orchestrator.is_generation: if _global_state.tokenizer_manager.is_generation:
gri = GenerateReqInput( gri = GenerateReqInput(
input_ids=[0], sampling_params=sampling_params, log_metrics=False input_ids=[0], sampling_params=sampling_params, log_metrics=False
) )
...@@ -134,7 +134,7 @@ async def health_generate(request: Request) -> Response: ...@@ -134,7 +134,7 @@ async def health_generate(request: Request) -> Response:
) )
try: try:
async for _ in _global_state.orchestrator.generate_request(gri, request): async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
break break
return Response(status_code=200) return Response(status_code=200)
except Exception as e: except Exception as e:
...@@ -146,9 +146,9 @@ async def health_generate(request: Request) -> Response: ...@@ -146,9 +146,9 @@ async def health_generate(request: Request) -> Response:
async def get_model_info(): async def get_model_info():
"""Get the model information.""" """Get the model information."""
result = { result = {
"model_path": _global_state.orchestrator.model_path, "model_path": _global_state.tokenizer_manager.model_path,
"tokenizer_path": _global_state.orchestrator.server_args.tokenizer_path, "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
"is_generation": _global_state.orchestrator.is_generation, "is_generation": _global_state.tokenizer_manager.is_generation,
} }
return result return result
...@@ -156,7 +156,7 @@ async def get_model_info(): ...@@ -156,7 +156,7 @@ async def get_model_info():
@app.get("/get_server_info") @app.get("/get_server_info")
async def get_server_info(): async def get_server_info():
return { return {
**dataclasses.asdict(_global_state.orchestrator.server_args), **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
**_global_state.scheduler_info, **_global_state.scheduler_info,
"version": __version__, "version": __version__,
} }
...@@ -170,7 +170,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): ...@@ -170,7 +170,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
async def stream_results() -> AsyncIterator[bytes]: async def stream_results() -> AsyncIterator[bytes]:
try: try:
async for out in _global_state.orchestrator.generate_request( async for out in _global_state.tokenizer_manager.generate_request(
obj, request obj, request
): ):
yield b"data: " + orjson.dumps( yield b"data: " + orjson.dumps(
...@@ -186,11 +186,11 @@ async def generate_request(obj: GenerateReqInput, request: Request): ...@@ -186,11 +186,11 @@ async def generate_request(obj: GenerateReqInput, request: Request):
return StreamingResponse( return StreamingResponse(
stream_results(), stream_results(),
media_type="text/event-stream", media_type="text/event-stream",
background=_global_state.orchestrator.create_abort_task(obj), background=_global_state.tokenizer_manager.create_abort_task(obj),
) )
else: else:
try: try:
ret = await _global_state.orchestrator.generate_request( ret = await _global_state.tokenizer_manager.generate_request(
obj, request obj, request
).__anext__() ).__anext__()
return ret return ret
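The streaming branch above emits `data: <json>` lines as a `text/event-stream`. Below is a hedged client-side sketch for consuming that stream; the host, port, endpoint path, and request fields are assumptions about a locally running server, not values taken from this diff.

```python
# Hedged sketch of a streaming client for a /generate-style endpoint.
# Host, port, and request fields are assumptions about a local sglang server.
import json
import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 32, "temperature": 0.0},
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue                          # skip blank keep-alive separators
    payload = line[len("data: "):]
    if payload == "[DONE]":               # end-of-stream marker, if the server sends one
        break
    chunk = json.loads(payload)
    print(chunk.get("text", ""), end="", flush=True)
```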
...@@ -203,7 +203,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): ...@@ -203,7 +203,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
async def encode_request(obj: EmbeddingReqInput, request: Request): async def encode_request(obj: EmbeddingReqInput, request: Request):
"""Handle an embedding request.""" """Handle an embedding request."""
try: try:
ret = await _global_state.orchestrator.generate_request( ret = await _global_state.tokenizer_manager.generate_request(
obj, request obj, request
).__anext__() ).__anext__()
return ret return ret
...@@ -215,7 +215,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request): ...@@ -215,7 +215,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
async def classify_request(obj: EmbeddingReqInput, request: Request): async def classify_request(obj: EmbeddingReqInput, request: Request):
"""Handle a reward model request. Now the arguments and return values are the same as embedding models.""" """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
try: try:
ret = await _global_state.orchestrator.generate_request( ret = await _global_state.tokenizer_manager.generate_request(
obj, request obj, request
).__anext__() ).__anext__()
return ret return ret
...@@ -226,7 +226,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request): ...@@ -226,7 +226,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
@app.post("/flush_cache") @app.post("/flush_cache")
async def flush_cache(): async def flush_cache():
"""Flush the radix cache.""" """Flush the radix cache."""
_global_state.orchestrator.flush_cache() _global_state.tokenizer_manager.flush_cache()
return Response( return Response(
content="Cache flushed.\nPlease check backend logs for more details. " content="Cache flushed.\nPlease check backend logs for more details. "
"(When there are running or waiting requests, the operation will not be performed.)\n", "(When there are running or waiting requests, the operation will not be performed.)\n",
...@@ -237,7 +237,7 @@ async def flush_cache(): ...@@ -237,7 +237,7 @@ async def flush_cache():
@app.api_route("/start_profile", methods=["GET", "POST"]) @app.api_route("/start_profile", methods=["GET", "POST"])
async def start_profile_async(): async def start_profile_async():
"""Start profiling.""" """Start profiling."""
_global_state.orchestrator.start_profile() _global_state.tokenizer_manager.start_profile()
return Response( return Response(
content="Start profiling.\n", content="Start profiling.\n",
status_code=200, status_code=200,
...@@ -247,7 +247,7 @@ async def start_profile_async(): ...@@ -247,7 +247,7 @@ async def start_profile_async():
@app.api_route("/stop_profile", methods=["GET", "POST"]) @app.api_route("/stop_profile", methods=["GET", "POST"])
async def stop_profile_async(): async def stop_profile_async():
"""Stop profiling.""" """Stop profiling."""
_global_state.orchestrator.stop_profile() _global_state.tokenizer_manager.stop_profile()
return Response( return Response(
content="Stop profiling. This will take some time.\n", content="Stop profiling. This will take some time.\n",
status_code=200, status_code=200,
...@@ -257,7 +257,7 @@ async def stop_profile_async(): ...@@ -257,7 +257,7 @@ async def stop_profile_async():
@app.post("/update_weights_from_disk") @app.post("/update_weights_from_disk")
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights from disk in-place without re-launching the server.""" """Update the weights from disk in-place without re-launching the server."""
success, message = await _global_state.orchestrator.update_weights_from_disk( success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
obj, request obj, request
) )
content = {"success": success, "message": message} content = {"success": success, "message": message}
...@@ -278,7 +278,7 @@ async def init_weights_update_group( ...@@ -278,7 +278,7 @@ async def init_weights_update_group(
obj: InitWeightsUpdateGroupReqInput, request: Request obj: InitWeightsUpdateGroupReqInput, request: Request
): ):
"""Initialize the parameter update group.""" """Initialize the parameter update group."""
success, message = await _global_state.orchestrator.init_weights_update_group( success, message = await _global_state.tokenizer_manager.init_weights_update_group(
obj, request obj, request
) )
content = {"success": success, "message": message} content = {"success": success, "message": message}
...@@ -293,8 +293,10 @@ async def update_weights_from_distributed( ...@@ -293,8 +293,10 @@ async def update_weights_from_distributed(
obj: UpdateWeightsFromDistributedReqInput, request: Request obj: UpdateWeightsFromDistributedReqInput, request: Request
): ):
"""Update model parameter from distributed online.""" """Update model parameter from distributed online."""
success, message = await _global_state.orchestrator.update_weights_from_distributed( success, message = (
obj, request await _global_state.tokenizer_manager.update_weights_from_distributed(
obj, request
)
) )
content = {"success": success, "message": message} content = {"success": success, "message": message}
if success: if success:
...@@ -307,7 +309,7 @@ async def update_weights_from_distributed( ...@@ -307,7 +309,7 @@ async def update_weights_from_distributed(
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
"""Get model parameter by name.""" """Get model parameter by name."""
try: try:
ret = await _global_state.orchestrator.get_weights_by_name(obj, request) ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request)
if ret is None: if ret is None:
return _create_error_response("Get parameter by name failed") return _create_error_response("Get parameter by name failed")
else: else:
...@@ -322,7 +324,7 @@ async def release_memory_occupation( ...@@ -322,7 +324,7 @@ async def release_memory_occupation(
): ):
"""Release GPU occupation temporarily""" """Release GPU occupation temporarily"""
try: try:
await _global_state.orchestrator.release_memory_occupation(obj, request) await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
except Exception as e: except Exception as e:
return _create_error_response(e) return _create_error_response(e)
...@@ -333,7 +335,7 @@ async def resume_memory_occupation( ...@@ -333,7 +335,7 @@ async def resume_memory_occupation(
): ):
"""Resume GPU occupation""" """Resume GPU occupation"""
try: try:
await _global_state.orchestrator.resume_memory_occupation(obj, request) await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
except Exception as e: except Exception as e:
return _create_error_response(e) return _create_error_response(e)
...@@ -342,7 +344,7 @@ async def resume_memory_occupation( ...@@ -342,7 +344,7 @@ async def resume_memory_occupation(
async def open_session(obj: OpenSessionReqInput, request: Request): async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id.""" """Open a session, and return its unique session id."""
try: try:
session_id = await _global_state.orchestrator.open_session(obj, request) session_id = await _global_state.tokenizer_manager.open_session(obj, request)
if session_id is None: if session_id is None:
raise Exception( raise Exception(
"Failed to open the session. Check if a session with the same id is still open." "Failed to open the session. Check if a session with the same id is still open."
...@@ -356,7 +358,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request): ...@@ -356,7 +358,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
async def close_session(obj: CloseSessionReqInput, request: Request): async def close_session(obj: CloseSessionReqInput, request: Request):
"""Close the session""" """Close the session"""
try: try:
await _global_state.orchestrator.close_session(obj, request) await _global_state.tokenizer_manager.close_session(obj, request)
return Response(status_code=200) return Response(status_code=200)
except Exception as e: except Exception as e:
return _create_error_response(e) return _create_error_response(e)
...@@ -365,7 +367,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request): ...@@ -365,7 +367,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
@app.api_route("/configure_logging", methods=["GET", "POST"]) @app.api_route("/configure_logging", methods=["GET", "POST"])
async def configure_logging(obj: ConfigureLoggingReq, request: Request): async def configure_logging(obj: ConfigureLoggingReq, request: Request):
"""Close the session""" """Close the session"""
_global_state.orchestrator.configure_logging(obj) _global_state.tokenizer_manager.configure_logging(obj)
return Response(status_code=200) return Response(status_code=200)
...@@ -396,24 +398,24 @@ async def function_call_request(obj: FunctionCallReqInput, request: Request): ...@@ -396,24 +398,24 @@ async def function_call_request(obj: FunctionCallReqInput, request: Request):
@app.post("/v1/completions") @app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request): async def openai_v1_completions(raw_request: Request):
return await v1_completions(_global_state.orchestrator, raw_request) return await v1_completions(_global_state.tokenizer_manager, raw_request)
@app.post("/v1/chat/completions") @app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request): async def openai_v1_chat_completions(raw_request: Request):
return await v1_chat_completions(_global_state.orchestrator, raw_request) return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
@app.post("/v1/embeddings", response_class=ORJSONResponse) @app.post("/v1/embeddings", response_class=ORJSONResponse)
async def openai_v1_embeddings(raw_request: Request): async def openai_v1_embeddings(raw_request: Request):
response = await v1_embeddings(_global_state.orchestrator, raw_request) response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
return response return response
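Because the routes above mirror the OpenAI REST API, the stock `openai` Python client (>= 1.0) can be pointed at a running server. A hedged sketch follows; the base URL, port, API key, and model name are assumptions about a local deployment.

```python
# Hedged sketch: call the /v1 routes above with the openai client.
# Base URL, API key, and model name are placeholders for a local deployment.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one word."}],
    max_tokens=8,
)
print(chat.choices[0].message.content)
```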
@app.get("/v1/models", response_class=ORJSONResponse) @app.get("/v1/models", response_class=ORJSONResponse)
def available_models(): def available_models():
"""Show available models.""" """Show available models."""
served_model_names = [_global_state.orchestrator.served_model_name] served_model_names = [_global_state.tokenizer_manager.served_model_name]
model_cards = [] model_cards = []
for served_model_name in served_model_names: for served_model_name in served_model_names:
model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
...@@ -423,7 +425,7 @@ def available_models(): ...@@ -423,7 +425,7 @@ def available_models():
@app.post("/v1/files") @app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
return await v1_files_create( return await v1_files_create(
file, purpose, _global_state.orchestrator.server_args.file_storage_pth file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth
) )
...@@ -435,13 +437,13 @@ async def delete_file(file_id: str): ...@@ -435,13 +437,13 @@ async def delete_file(file_id: str):
@app.post("/v1/batches") @app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request): async def openai_v1_batches(raw_request: Request):
return await v1_batches(_global_state.orchestrator, raw_request) return await v1_batches(_global_state.tokenizer_manager, raw_request)
@app.post("/v1/batches/{batch_id}/cancel") @app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str): async def cancel_batches(batch_id: str):
# https://platform.openai.com/docs/api-reference/batch/cancel # https://platform.openai.com/docs/api-reference/batch/cancel
return await v1_cancel_batch(_global_state.orchestrator, batch_id) return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
@app.get("/v1/batches/{batch_id}") @app.get("/v1/batches/{batch_id}")
...@@ -490,18 +492,18 @@ def launch_server( ...@@ -490,18 +492,18 @@ def launch_server(
- HTTP server: A FastAPI server that routes requests to the engine. - HTTP server: A FastAPI server that routes requests to the engine.
- The engine consists of three components: - The engine consists of three components:
1. StdOrchestrator: Tokenizes the requests and sends them to the scheduler. 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
Note: Note:
1. The HTTP server, Engine, and StdOrchestrator all run in the main process. 1. The HTTP server, Engine, and TokenizerManager all run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library. 2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
""" """
orchestrator, scheduler_info = _launch_subprocesses(server_args=server_args) tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
set_global_state( set_global_state(
_GlobalState( _GlobalState(
orchestrator=orchestrator, tokenizer_manager=tokenizer_manager,
scheduler_info=scheduler_info, scheduler_info=scheduler_info,
) )
) )
...@@ -521,7 +523,7 @@ def launch_server( ...@@ -521,7 +523,7 @@ def launch_server(
args=( args=(
server_args, server_args,
pipe_finish_writer, pipe_finish_writer,
_global_state.orchestrator.image_token_id, _global_state.tokenizer_manager.image_token_id,
), ),
) )
t.start() t.start()
......
...@@ -241,7 +241,7 @@ class LlavaImageProcessor(BaseImageProcessor): ...@@ -241,7 +241,7 @@ class LlavaImageProcessor(BaseImageProcessor):
return pixel_values, image_hash, image.size return pixel_values, image_hash, image.size
except Exception: except Exception:
logger.error("Exception in StdOrchestrator:\n" + get_exception_traceback()) logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image( async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
...@@ -491,7 +491,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor): ...@@ -491,7 +491,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
return pixel_values, image_hash, image.size, image_grid_thws return pixel_values, image_hash, image.size, image_grid_thws
except Exception: except Exception:
logger.error("Exception in StdOrchestrator:\n" + get_exception_traceback()) logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image(self, image_data: Union[bytes, str]): async def _process_single_image(self, image_data: Union[bytes, str]):
if self.executor is not None: if self.executor is not None:
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# ============================================================================== # ==============================================================================
""" """
The definition of objects transferred between different The definition of objects transferred between different
processes (StdOrchestrator, DetokenizerManager, Controller). processes (TokenizerManager, DetokenizerManager, Controller).
""" """
import uuid import uuid
......
...@@ -174,7 +174,7 @@ class Scheduler: ...@@ -174,7 +174,7 @@ class Scheduler:
) )
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
# Directly send to the StdOrchestrator # Directly send to the TokenizerManager
self.send_to_detokenizer = get_zmq_socket( self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False context, zmq.PUSH, port_args.tokenizer_ipc_name, False
) )
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""
import asyncio import asyncio
import logging import logging
...@@ -65,8 +66,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) ...@@ -65,8 +66,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class StdOrchestrator: class TokenizerManager:
"""StdOrchestrator is the primary entrypoint of orchestration.std package""" """TokenizerManager is a process that tokenizes the text."""
def __init__( def __init__(
self, self,
...@@ -438,20 +439,20 @@ async def print_exception_wrapper(func): ...@@ -438,20 +439,20 @@ async def print_exception_wrapper(func):
await func() await func()
except Exception: except Exception:
traceback = get_exception_traceback() traceback = get_exception_traceback()
logger.error(f"StdOrchestrator hit an exception: {traceback}") logger.error(f"TokenizerManager hit an exception: {traceback}")
kill_process_tree(os.getpid(), include_parent=True) kill_process_tree(os.getpid(), include_parent=True)
sys.exit(1) sys.exit(1)
class SignalHandler: class SignalHandler:
def __init__(self, orchestrator): def __init__(self, tokenizer_manager):
self.orchestrator = orchestrator self.tokenizer_manager = tokenizer_manager
def signal_handler(self, signum=None, frame=None): def signal_handler(self, signum=None, frame=None):
logger.warning( logger.warning(
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..." f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
) )
self.orchestrator.gracefully_exit = True self.tokenizer_manager.gracefully_exit = True
T = TypeVar("T") T = TypeVar("T")
......
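The SignalHandler above only flips `gracefully_exit`; something else must register it and poll the flag. Below is a hedged sketch of how such a handler is typically wired up with the standard library; the manager object and drain loop are stand-ins, not sglang's actual event loop.

```python
# Hedged sketch: register a SIGTERM handler that sets a flag, then let an async
# loop notice the flag and drain work before exiting. Manager and loop are stand-ins.
import asyncio
import signal

class Manager:
    gracefully_exit = False

class SignalHandler:
    def __init__(self, manager):
        self.manager = manager

    def signal_handler(self, signum=None, frame=None):
        print(f"SIGTERM received. {signum=}. Draining requests and shutting down...")
        self.manager.gracefully_exit = True

async def main():
    manager = Manager()
    signal.signal(signal.SIGTERM, SignalHandler(manager).signal_handler)
    while not manager.gracefully_exit:
        await asyncio.sleep(0.5)          # stand-in for serving requests
    print("drained; exiting")

if __name__ == "__main__":
    asyncio.run(main())                   # send SIGTERM to this process to trigger shutdown
```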
...@@ -117,7 +117,7 @@ def create_streaming_error_response( ...@@ -117,7 +117,7 @@ def create_streaming_error_response(
return json_str return json_str
def load_chat_template_for_openai_api(orchestrator, chat_template_arg, model_path): def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, model_path):
global chat_template_name global chat_template_name
logger.info( logger.info(
...@@ -133,7 +133,9 @@ def load_chat_template_for_openai_api(orchestrator, chat_template_arg, model_pat ...@@ -133,7 +133,9 @@ def load_chat_template_for_openai_api(orchestrator, chat_template_arg, model_pat
if chat_template_arg.endswith(".jinja"): if chat_template_arg.endswith(".jinja"):
with open(chat_template_arg, "r") as f: with open(chat_template_arg, "r") as f:
chat_template = "".join(f.readlines()).strip("\n") chat_template = "".join(f.readlines()).strip("\n")
orchestrator.tokenizer.chat_template = chat_template.replace("\\n", "\n") tokenizer_manager.tokenizer.chat_template = chat_template.replace(
"\\n", "\n"
)
chat_template_name = None chat_template_name = None
else: else:
assert chat_template_arg.endswith( assert chat_template_arg.endswith(
...@@ -229,7 +231,7 @@ async def v1_delete_file(file_id: str): ...@@ -229,7 +231,7 @@ async def v1_delete_file(file_id: str):
return FileDeleteResponse(id=file_id, deleted=True) return FileDeleteResponse(id=file_id, deleted=True)
async def v1_batches(orchestrator, raw_request: Request): async def v1_batches(tokenizer_manager, raw_request: Request):
try: try:
body = await raw_request.json() body = await raw_request.json()
...@@ -250,7 +252,7 @@ async def v1_batches(orchestrator, raw_request: Request): ...@@ -250,7 +252,7 @@ async def v1_batches(orchestrator, raw_request: Request):
batch_storage[batch_id] = batch_response batch_storage[batch_id] = batch_response
# Start processing the batch asynchronously # Start processing the batch asynchronously
asyncio.create_task(process_batch(orchestrator, batch_id, batch_request)) asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
# Return the initial batch_response # Return the initial batch_response
return batch_response return batch_response
...@@ -261,7 +263,7 @@ async def v1_batches(orchestrator, raw_request: Request): ...@@ -261,7 +263,7 @@ async def v1_batches(orchestrator, raw_request: Request):
return {"error": str(e)} return {"error": str(e)}
async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest): async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
try: try:
# Update the batch status to "in_progress" # Update the batch status to "in_progress"
batch_storage[batch_id].status = "in_progress" batch_storage[batch_id].status = "in_progress"
...@@ -304,7 +306,7 @@ async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest ...@@ -304,7 +306,7 @@ async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest
if end_point == "/v1/chat/completions": if end_point == "/v1/chat/completions":
adapted_request, request = v1_chat_generate_request( adapted_request, request = v1_chat_generate_request(
all_requests, orchestrator, request_ids=request_ids all_requests, tokenizer_manager, request_ids=request_ids
) )
elif end_point == "/v1/completions": elif end_point == "/v1/completions":
adapted_request, request = v1_generate_request( adapted_request, request = v1_generate_request(
...@@ -312,7 +314,7 @@ async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest ...@@ -312,7 +314,7 @@ async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest
) )
try: try:
ret = await orchestrator.generate_request(adapted_request).__anext__() ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
if not isinstance(ret, list): if not isinstance(ret, list):
ret = [ret] ret = [ret]
if end_point == "/v1/chat/completions": if end_point == "/v1/chat/completions":
...@@ -320,12 +322,12 @@ async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest ...@@ -320,12 +322,12 @@ async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest
request, request,
ret, ret,
to_file=True, to_file=True,
cache_report=orchestrator.server_args.enable_cache_report, cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=orchestrator.server_args.tool_call_parser, tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
) )
else: else:
responses = v1_generate_response( responses = v1_generate_response(
request, ret, orchestrator, to_file=True request, ret, tokenizer_manager, to_file=True
) )
except Exception as e: except Exception as e:
...@@ -397,7 +399,7 @@ async def v1_retrieve_batch(batch_id: str): ...@@ -397,7 +399,7 @@ async def v1_retrieve_batch(batch_id: str):
return batch_response return batch_response
async def v1_cancel_batch(orchestrator, batch_id: str): async def v1_cancel_batch(tokenizer_manager, batch_id: str):
# Retrieve the batch job from the in-memory storage # Retrieve the batch job from the in-memory storage
batch_response = batch_storage.get(batch_id) batch_response = batch_storage.get(batch_id)
if batch_response is None: if batch_response is None:
...@@ -408,7 +410,7 @@ async def v1_cancel_batch(orchestrator, batch_id: str): ...@@ -408,7 +410,7 @@ async def v1_cancel_batch(orchestrator, batch_id: str):
# Start cancelling the batch asynchronously # Start cancelling the batch asynchronously
asyncio.create_task( asyncio.create_task(
cancel_batch( cancel_batch(
orchestrator=orchestrator, tokenizer_manager=tokenizer_manager,
batch_id=batch_id, batch_id=batch_id,
input_file_id=batch_response.input_file_id, input_file_id=batch_response.input_file_id,
) )
...@@ -425,7 +427,7 @@ async def v1_cancel_batch(orchestrator, batch_id: str): ...@@ -425,7 +427,7 @@ async def v1_cancel_batch(orchestrator, batch_id: str):
) )
async def cancel_batch(orchestrator, batch_id: str, input_file_id: str): async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
try: try:
# Update the batch status to "cancelling" # Update the batch status to "cancelling"
batch_storage[batch_id].status = "cancelling" batch_storage[batch_id].status = "cancelling"
...@@ -449,7 +451,7 @@ async def cancel_batch(orchestrator, batch_id: str, input_file_id: str): ...@@ -449,7 +451,7 @@ async def cancel_batch(orchestrator, batch_id: str, input_file_id: str):
# Cancel requests by request_ids # Cancel requests by request_ids
for rid in request_ids: for rid in request_ids:
orchestrator.abort_request(rid=rid) tokenizer_manager.abort_request(rid=rid)
retrieve_batch = batch_storage[batch_id] retrieve_batch = batch_storage[batch_id]
retrieve_batch.status = "cancelled" retrieve_batch.status = "cancelled"
...@@ -577,7 +579,7 @@ def v1_generate_request( ...@@ -577,7 +579,7 @@ def v1_generate_request(
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
def v1_generate_response(request, ret, orchestrator, to_file=False): def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
choices = [] choices = []
echo = False echo = False
...@@ -589,13 +591,15 @@ def v1_generate_response(request, ret, orchestrator, to_file=False): ...@@ -589,13 +591,15 @@ def v1_generate_response(request, ret, orchestrator, to_file=False):
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list): elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
# for the case of multiple token ids prompts # for the case of multiple token ids prompts
prompts = [ prompts = [
orchestrator.tokenizer.decode(prompt, skip_special_tokens=True) tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
for prompt in request.prompt for prompt in request.prompt
] ]
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int): elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
# for the case of single token ids prompt # for the case of single token ids prompt
prompts = [ prompts = [
orchestrator.tokenizer.decode(request.prompt, skip_special_tokens=True) tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
] ]
else: else:
# for the case of single str prompt # for the case of single str prompt
...@@ -705,7 +709,7 @@ def v1_generate_response(request, ret, orchestrator, to_file=False): ...@@ -705,7 +709,7 @@ def v1_generate_response(request, ret, orchestrator, to_file=False):
return response return response
async def v1_completions(orchestrator, raw_request: Request): async def v1_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json() request_json = await raw_request.json()
all_requests = [CompletionRequest(**request_json)] all_requests = [CompletionRequest(**request_json)]
adapted_request, request = v1_generate_request(all_requests) adapted_request, request = v1_generate_request(all_requests)
...@@ -718,7 +722,7 @@ async def v1_completions(orchestrator, raw_request: Request): ...@@ -718,7 +722,7 @@ async def v1_completions(orchestrator, raw_request: Request):
prompt_tokens = {} prompt_tokens = {}
completion_tokens = {} completion_tokens = {}
try: try:
async for content in orchestrator.generate_request( async for content in tokenizer_manager.generate_request(
adapted_request, raw_request adapted_request, raw_request
): ):
index = content.get("index", 0) index = content.get("index", 0)
...@@ -741,14 +745,14 @@ async def v1_completions(orchestrator, raw_request: Request): ...@@ -741,14 +745,14 @@ async def v1_completions(orchestrator, raw_request: Request):
prompts = request.prompt[index // request.n] prompts = request.prompt[index // request.n]
elif isinstance(request.prompt[0], int): elif isinstance(request.prompt[0], int):
# for the case of single token ids prompt # for the case of single token ids prompt
prompts = orchestrator.tokenizer.decode( prompts = tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True request.prompt, skip_special_tokens=True
) )
elif isinstance(request.prompt[0], list) and isinstance( elif isinstance(request.prompt[0], list) and isinstance(
request.prompt[0][0], int request.prompt[0][0], int
): ):
# for the case of multiple token ids prompts # for the case of multiple token ids prompts
prompts = orchestrator.tokenizer.decode( prompts = tokenizer_manager.tokenizer.decode(
request.prompt[index // request.n], request.prompt[index // request.n],
skip_special_tokens=True, skip_special_tokens=True,
) )
...@@ -843,12 +847,12 @@ async def v1_completions(orchestrator, raw_request: Request): ...@@ -843,12 +847,12 @@ async def v1_completions(orchestrator, raw_request: Request):
return StreamingResponse( return StreamingResponse(
generate_stream_resp(), generate_stream_resp(),
media_type="text/event-stream", media_type="text/event-stream",
background=orchestrator.create_abort_task(adapted_request), background=tokenizer_manager.create_abort_task(adapted_request),
) )
# Non-streaming response. # Non-streaming response.
try: try:
ret = await orchestrator.generate_request( ret = await tokenizer_manager.generate_request(
adapted_request, raw_request adapted_request, raw_request
).__anext__() ).__anext__()
except ValueError as e: except ValueError as e:
...@@ -857,13 +861,13 @@ async def v1_completions(orchestrator, raw_request: Request): ...@@ -857,13 +861,13 @@ async def v1_completions(orchestrator, raw_request: Request):
if not isinstance(ret, list): if not isinstance(ret, list):
ret = [ret] ret = [ret]
response = v1_generate_response(request, ret, orchestrator) response = v1_generate_response(request, ret, tokenizer_manager)
return response return response
def v1_chat_generate_request( def v1_chat_generate_request(
all_requests: List[ChatCompletionRequest], all_requests: List[ChatCompletionRequest],
orchestrator, tokenizer_manager,
request_ids: List[str] = None, request_ids: List[str] = None,
): ):
input_ids = [] input_ids = []
...@@ -918,7 +922,7 @@ def v1_chat_generate_request( ...@@ -918,7 +922,7 @@ def v1_chat_generate_request(
assistant_prefix = None assistant_prefix = None
try: try:
prompt_ids = orchestrator.tokenizer.apply_chat_template( prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages, openai_compatible_messages,
tokenize=True, tokenize=True,
add_generation_prompt=True, add_generation_prompt=True,
...@@ -929,7 +933,7 @@ def v1_chat_generate_request( ...@@ -929,7 +933,7 @@ def v1_chat_generate_request(
# has a different tools input format that is not compatible # has a different tools input format that is not compatible
# with OpenAI's apply_chat_template tool_call format, like Mistral. # with OpenAI's apply_chat_template tool_call format, like Mistral.
tools = [t if "function" in t else {"function": t} for t in tools] tools = [t if "function" in t else {"function": t} for t in tools]
prompt_ids = orchestrator.tokenizer.apply_chat_template( prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages, openai_compatible_messages,
tokenize=True, tokenize=True,
add_generation_prompt=True, add_generation_prompt=True,
...@@ -937,8 +941,11 @@ def v1_chat_generate_request( ...@@ -937,8 +941,11 @@ def v1_chat_generate_request(
) )
if assistant_prefix: if assistant_prefix:
encoded = orchestrator.tokenizer.encode(assistant_prefix) encoded = tokenizer_manager.tokenizer.encode(assistant_prefix)
if encoded and encoded[0] == orchestrator.tokenizer.bos_token_id: if (
encoded
and encoded[0] == tokenizer_manager.tokenizer.bos_token_id
):
encoded = encoded[1:] encoded = encoded[1:]
prompt_ids += encoded prompt_ids += encoded
stop = request.stop stop = request.stop
...@@ -955,7 +962,7 @@ def v1_chat_generate_request( ...@@ -955,7 +962,7 @@ def v1_chat_generate_request(
stop.append(request.stop) stop.append(request.stop)
else: else:
stop.extend(request.stop) stop.extend(request.stop)
prompt_ids = orchestrator.tokenizer.encode(prompt) prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
else: else:
# Use the raw prompt and stop strings if the messages is already a string. # Use the raw prompt and stop strings if the messages is already a string.
prompt_ids = request.messages prompt_ids = request.messages
...@@ -1194,10 +1201,10 @@ def v1_chat_generate_response( ...@@ -1194,10 +1201,10 @@ def v1_chat_generate_response(
return response return response
async def v1_chat_completions(orchestrator, raw_request: Request): async def v1_chat_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json() request_json = await raw_request.json()
all_requests = [ChatCompletionRequest(**request_json)] all_requests = [ChatCompletionRequest(**request_json)]
adapted_request, request = v1_chat_generate_request(all_requests, orchestrator) adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
if adapted_request.stream: if adapted_request.stream:
parser_dict = {} parser_dict = {}
...@@ -1209,7 +1216,7 @@ async def v1_chat_completions(orchestrator, raw_request: Request): ...@@ -1209,7 +1216,7 @@ async def v1_chat_completions(orchestrator, raw_request: Request):
prompt_tokens = {} prompt_tokens = {}
completion_tokens = {} completion_tokens = {}
try: try:
async for content in orchestrator.generate_request( async for content in tokenizer_manager.generate_request(
adapted_request, raw_request adapted_request, raw_request
): ):
index = content.get("index", 0) index = content.get("index", 0)
...@@ -1299,7 +1306,7 @@ async def v1_chat_completions(orchestrator, raw_request: Request): ...@@ -1299,7 +1306,7 @@ async def v1_chat_completions(orchestrator, raw_request: Request):
if index not in parser_dict: if index not in parser_dict:
parser_dict[index] = FunctionCallParser( parser_dict[index] = FunctionCallParser(
tools=request.tools, tools=request.tools,
tool_call_parser=orchestrator.server_args.tool_call_parser, tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
) )
parser = parser_dict[index] parser = parser_dict[index]
...@@ -1431,12 +1438,12 @@ async def v1_chat_completions(orchestrator, raw_request: Request): ...@@ -1431,12 +1438,12 @@ async def v1_chat_completions(orchestrator, raw_request: Request):
return StreamingResponse( return StreamingResponse(
generate_stream_resp(), generate_stream_resp(),
media_type="text/event-stream", media_type="text/event-stream",
background=orchestrator.create_abort_task(adapted_request), background=tokenizer_manager.create_abort_task(adapted_request),
) )
# Non-streaming response. # Non-streaming response.
try: try:
ret = await orchestrator.generate_request( ret = await tokenizer_manager.generate_request(
adapted_request, raw_request adapted_request, raw_request
).__anext__() ).__anext__()
except ValueError as e: except ValueError as e:
...@@ -1447,14 +1454,14 @@ async def v1_chat_completions(orchestrator, raw_request: Request): ...@@ -1447,14 +1454,14 @@ async def v1_chat_completions(orchestrator, raw_request: Request):
response = v1_chat_generate_response( response = v1_chat_generate_response(
request, request,
ret, ret,
cache_report=orchestrator.server_args.enable_cache_report, cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=orchestrator.server_args.tool_call_parser, tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
) )
return response return response
def v1_embedding_request(all_requests, orchestrator): def v1_embedding_request(all_requests, tokenizer_manager):
prompts = [] prompts = []
sampling_params_list = [] sampling_params_list = []
first_prompt_type = type(all_requests[0].input) first_prompt_type = type(all_requests[0].input)
...@@ -1509,13 +1516,13 @@ def v1_embedding_response(ret, model_path, to_file=False): ...@@ -1509,13 +1516,13 @@ def v1_embedding_response(ret, model_path, to_file=False):
) )
async def v1_embeddings(orchestrator, raw_request: Request): async def v1_embeddings(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json() request_json = await raw_request.json()
all_requests = [EmbeddingRequest(**request_json)] all_requests = [EmbeddingRequest(**request_json)]
adapted_request, request = v1_embedding_request(all_requests, orchestrator) adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)
try: try:
ret = await orchestrator.generate_request( ret = await tokenizer_manager.generate_request(
adapted_request, raw_request adapted_request, raw_request
).__anext__() ).__anext__()
except ValueError as e: except ValueError as e:
...@@ -1524,7 +1531,7 @@ async def v1_embeddings(orchestrator, raw_request: Request): ...@@ -1524,7 +1531,7 @@ async def v1_embeddings(orchestrator, raw_request: Request):
if not isinstance(ret, list): if not isinstance(ret, list):
ret = [ret] ret = [ret]
response = v1_embedding_response(ret, orchestrator.model_path) response = v1_embedding_response(ret, tokenizer_manager.model_path)
return response return response
......
...@@ -1035,7 +1035,7 @@ class PortArgs: ...@@ -1035,7 +1035,7 @@ class PortArgs:
if dp_rank is None: if dp_rank is None:
scheduler_input_port = ( scheduler_input_port = (
port_base + 2 port_base + 2
) # StdOrchestrator to DataParallelController ) # TokenizerManager to DataParallelController
else: else:
scheduler_input_port = port_base + 2 + 1 + dp_rank scheduler_input_port = port_base + 2 + 1 + dp_rank
......