Unverified Commit d5de20a3 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Fix `sync()` when `fork(1)` (#412)

parent 4a1c6ae2
...@@ -232,9 +232,15 @@ register_chat_template( ...@@ -232,9 +232,15 @@ register_chat_template(
name="c4ai-command-r", name="c4ai-command-r",
default_system_prompt=None, default_system_prompt=None,
role_prefix_and_suffix={ role_prefix_and_suffix={
"system": ("<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", "<|END_OF_TURN_TOKEN|>"), "system": (
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
"<|END_OF_TURN_TOKEN|>",
),
"user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"), "user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
"assistant": ("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", "<|END_OF_TURN_TOKEN|>"), "assistant": (
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
"<|END_OF_TURN_TOKEN|>",
),
}, },
style=ChatTemplateStyle.PLAIN, style=ChatTemplateStyle.PLAIN,
) )
......
"""The interpreter that executes SGL programs""" """The interpreter that executes SGL programs"""
import asyncio import asyncio
import contextvars
import multiprocessing import multiprocessing
import queue import queue
import threading import threading
...@@ -9,7 +10,6 @@ from concurrent.futures import ThreadPoolExecutor ...@@ -9,7 +10,6 @@ from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import contextvars
import tqdm import tqdm
from sglang.global_config import global_config from sglang.global_config import global_config
...@@ -222,7 +222,10 @@ class StreamExecutor: ...@@ -222,7 +222,10 @@ class StreamExecutor:
def _run_worker_in_context(): def _run_worker_in_context():
self._thread_worker_func() self._thread_worker_func()
self.worker = threading.Thread(target=contextvars.copy_context().run, args=(_run_worker_in_context, ))
self.worker = threading.Thread(
target=contextvars.copy_context().run, args=(_run_worker_in_context,)
)
self.worker.start() self.worker.start()
# For streaming # For streaming
...@@ -265,12 +268,11 @@ class StreamExecutor: ...@@ -265,12 +268,11 @@ class StreamExecutor:
self, self,
number: int, number: int,
position_ids_offset: Optional[List[int]] = None, position_ids_offset: Optional[List[int]] = None,
copy: bool = False,
): ):
if number > 1 or copy: if number > 1:
self.submit(SglCommitLazy()) self.submit(SglCommitLazy())
self.sync()
self.sync()
number = int(number) number = int(number)
exes = [ exes = [
...@@ -656,16 +658,15 @@ class ProgramState: ...@@ -656,16 +658,15 @@ class ProgramState:
self, self,
number: int = 1, number: int = 1,
position_ids_offset: Optional[List[int]] = None, position_ids_offset: Optional[List[int]] = None,
copy: bool = False,
): ):
stream_executors = self.stream_executor.fork(number, position_ids_offset, copy) stream_executors = self.stream_executor.fork(number, position_ids_offset)
states = [ProgramState(x) for x in stream_executors] states = [ProgramState(x) for x in stream_executors]
state_group = ProgramStateGroup(states, self) state_group = ProgramStateGroup(states, self)
return state_group return state_group
@contextmanager @contextmanager
def copy(self, position_ids_offset: Optional[List[int]] = None): def copy(self, position_ids_offset: Optional[List[int]] = None):
state_group = self.fork(1, position_ids_offset, True) state_group = self.fork(1, position_ids_offset)
try: try:
yield state_group[0] yield state_group[0]
finally: finally:
......
...@@ -42,7 +42,9 @@ class DetokenizerManager: ...@@ -42,7 +42,9 @@ class DetokenizerManager:
output_strs = self.tokenizer.batch_decode( output_strs = self.tokenizer.batch_decode(
output_tokens, output_tokens,
skip_special_tokens=recv_obj.skip_special_tokens[0], skip_special_tokens=recv_obj.skip_special_tokens[0],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
0
],
) )
# Trim stop str # Trim stop str
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment