"git@developer.sourcefind.cn:change/sglang.git" did not exist on "b3e99dfb2292ee9de83ca1a29800dff900da19af"
Unverified Commit 1bf1cf19 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Reduce overhead when `fork(1)` (#375)

parent e822e590
...@@ -256,9 +256,15 @@ class StreamExecutor: ...@@ -256,9 +256,15 @@ class StreamExecutor:
ret = self.meta_info.get(name, None) ret = self.meta_info.get(name, None)
return ret return ret
def fork(self, number: int, position_ids_offset: Optional[List[int]] = None): def fork(
self.submit(SglCommitLazy()) self,
self.sync() number: int,
position_ids_offset: Optional[List[int]] = None,
copy: bool = False,
):
if number > 1 or copy:
self.submit(SglCommitLazy())
self.sync()
number = int(number) number = int(number)
...@@ -641,15 +647,20 @@ class ProgramState: ...@@ -641,15 +647,20 @@ class ProgramState:
yield yield
self.stream_executor.submit(SglVarScopeEnd(name)) self.stream_executor.submit(SglVarScopeEnd(name))
def fork(self, number: int = 1, position_ids_offset: Optional[List[int]] = None): def fork(
stream_executors = self.stream_executor.fork(number, position_ids_offset) self,
number: int = 1,
position_ids_offset: Optional[List[int]] = None,
copy: bool = False,
):
stream_executors = self.stream_executor.fork(number, position_ids_offset, copy)
states = [ProgramState(x) for x in stream_executors] states = [ProgramState(x) for x in stream_executors]
state_group = ProgramStateGroup(states, self) state_group = ProgramStateGroup(states, self)
return state_group return state_group
@contextmanager @contextmanager
def copy(self, position_ids_offset: Optional[List[int]] = None): def copy(self, position_ids_offset: Optional[List[int]] = None):
state_group = self.fork(1, position_ids_offset) state_group = self.fork(1, position_ids_offset, True)
try: try:
yield state_group[0] yield state_group[0]
finally: finally:
......
...@@ -179,7 +179,9 @@ class RadixCache: ...@@ -179,7 +179,9 @@ class RadixCache:
def _print_helper(self, node, indent): def _print_helper(self, node, indent):
for _, child in node.children.items(): for _, child in node.children.items():
print(" " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}") print(
" " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
)
self._print_helper(child, indent=indent + 2) self._print_helper(child, indent=indent + 2)
def _delete_leaf(self, node): def _delete_leaf(self, node):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment