"tutorials/models/vscode:/vscode.git/clone" did not exist on "f370e628cdf6dcff6f32392a35e1789be68630a8"
Unverified commit d5de20a3, authored by Liangsheng Yin, committed by GitHub
Browse files

Fix `sync()` when `fork(1)` (#412)

parent 4a1c6ae2
......@@ -232,9 +232,15 @@ register_chat_template(
name="c4ai-command-r",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
"system": (
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
"<|END_OF_TURN_TOKEN|>",
),
"user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
"assistant": ("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
"assistant": (
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
"<|END_OF_TURN_TOKEN|>",
),
},
style=ChatTemplateStyle.PLAIN,
)
......
"""The interpreter that executes SGL programs"""
import asyncio
import contextvars
import multiprocessing
import queue
import threading
......@@ -9,7 +10,6 @@ from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union
import contextvars
import tqdm
from sglang.global_config import global_config
......@@ -222,7 +222,10 @@ class StreamExecutor:
def _run_worker_in_context():
self._thread_worker_func()
self.worker = threading.Thread(target=contextvars.copy_context().run, args=(_run_worker_in_context, ))
self.worker = threading.Thread(
target=contextvars.copy_context().run, args=(_run_worker_in_context,)
)
self.worker.start()
# For streaming
......@@ -265,12 +268,11 @@ class StreamExecutor:
self,
number: int,
position_ids_offset: Optional[List[int]] = None,
copy: bool = False,
):
if number > 1 or copy:
if number > 1:
self.submit(SglCommitLazy())
self.sync()
self.sync()
number = int(number)
exes = [
......@@ -656,16 +658,15 @@ class ProgramState:
self,
number: int = 1,
position_ids_offset: Optional[List[int]] = None,
copy: bool = False,
):
stream_executors = self.stream_executor.fork(number, position_ids_offset, copy)
stream_executors = self.stream_executor.fork(number, position_ids_offset)
states = [ProgramState(x) for x in stream_executors]
state_group = ProgramStateGroup(states, self)
return state_group
@contextmanager
def copy(self, position_ids_offset: Optional[List[int]] = None):
state_group = self.fork(1, position_ids_offset, True)
state_group = self.fork(1, position_ids_offset)
try:
yield state_group[0]
finally:
......
......@@ -42,7 +42,9 @@ class DetokenizerManager:
output_strs = self.tokenizer.batch_decode(
output_tokens,
skip_special_tokens=recv_obj.skip_special_tokens[0],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
0
],
)
# Trim stop str
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment