Unverified Commit daf593a3 authored by ObjectNotFound's avatar ObjectNotFound Committed by GitHub
Browse files

Fix streaming bug (#820)

parent bece265f
...@@ -553,6 +553,7 @@ class StreamExecutor: ...@@ -553,6 +553,7 @@ class StreamExecutor:
"output_token_logprobs": output_token_logprobs, "output_token_logprobs": output_token_logprobs,
} }
self.variable_event[name].set() self.variable_event[name].set()
self.stream_var_event[name].set()
self.text_ += decision self.text_ += decision
def _execute_variable(self, expr: SglVariable): def _execute_variable(self, expr: SglVariable):
...@@ -778,7 +779,14 @@ class ProgramState: ...@@ -778,7 +779,14 @@ class ProgramState:
if self.stream_executor.is_finished: if self.stream_executor.is_finished:
break break
else: else:
event = self.stream_executor.stream_var_event[var_name] event = None
while not event:
if var_name in self.stream_executor.stream_var_event:
event = self.stream_executor.stream_var_event[var_name]
if self.stream_executor.is_finished:
yield ""
return
while True: while True:
event.wait() event.wait()
event.clear() event.clear()
...@@ -813,7 +821,14 @@ class ProgramState: ...@@ -813,7 +821,14 @@ class ProgramState:
if self.stream_executor.is_finished: if self.stream_executor.is_finished:
break break
else: else:
event = self.stream_executor.stream_var_event[var_name] event = None
while not event:
if var_name in self.stream_executor.stream_var_event:
event = self.stream_executor.stream_var_event[var_name]
if self.stream_executor.is_finished:
yield ""
return
while True: while True:
await loop.run_in_executor(None, event.wait) await loop.run_in_executor(None, event.wait)
event.clear() event.clear()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment