Unverified Commit 7c38ed0f authored by Andrew Xia's avatar Andrew Xia Committed by GitHub
Browse files

[Frontend] split append tool output (#28333)


Signed-off-by: default avatarAndrew Xia <axia@fb.com>
Co-authored-by: default avatarAndrew Xia <axia@fb.com>
parent a1d3866d
...@@ -34,6 +34,9 @@ class MockConversationContext(ConversationContext): ...@@ -34,6 +34,9 @@ class MockConversationContext(ConversationContext):
def append_output(self, output) -> None: def append_output(self, output) -> None:
pass pass
def append_tool_output(self, output) -> None:
pass
async def call_tool(self): async def call_tool(self):
return [] return []
......
...@@ -80,7 +80,11 @@ class TurnMetrics: ...@@ -80,7 +80,11 @@ class TurnMetrics:
class ConversationContext(ABC): class ConversationContext(ABC):
@abstractmethod @abstractmethod
def append_output(self, output) -> None: def append_output(self, output: RequestOutput) -> None:
pass
@abstractmethod
def append_tool_output(self, output) -> None:
pass pass
@abstractmethod @abstractmethod
...@@ -151,6 +155,9 @@ class SimpleContext(ConversationContext): ...@@ -151,6 +155,9 @@ class SimpleContext(ConversationContext):
self.num_cached_tokens = output.num_cached_tokens or 0 self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or []) self.num_output_tokens += len(output.outputs[0].token_ids or [])
def append_tool_output(self, output) -> None:
raise NotImplementedError("Should not be called.")
def need_builtin_tool_call(self) -> bool: def need_builtin_tool_call(self) -> bool:
return False return False
...@@ -205,28 +212,28 @@ class HarmonyContext(ConversationContext): ...@@ -205,28 +212,28 @@ class HarmonyContext(ConversationContext):
if self.parser.current_channel in {"analysis", "commentary"}: if self.parser.current_channel in {"analysis", "commentary"}:
self.num_reasoning_tokens += 1 self.num_reasoning_tokens += 1
def append_output(self, output: RequestOutput | list[Message]) -> None: def append_output(self, output: RequestOutput) -> None:
if isinstance(output, RequestOutput): output_token_ids = output.outputs[0].token_ids
output_token_ids = output.outputs[0].token_ids self.parser = get_streamable_parser_for_assistant()
self.parser = get_streamable_parser_for_assistant() for token_id in output_token_ids:
for token_id in output_token_ids: self.parser.process(token_id)
self.parser.process(token_id) # Check if the current token is part of reasoning content
# Check if the current token is part of reasoning content self._update_num_reasoning_tokens()
self._update_num_reasoning_tokens() self._update_prefill_token_usage(output)
self._update_prefill_token_usage(output) self._update_decode_token_usage(output)
self._update_decode_token_usage(output) # Append current turn to all turn list for next turn's calculations
# Append current turn to all turn list for next turn's calculations self.all_turn_metrics.append(self.current_turn_metrics.copy())
self.all_turn_metrics.append(self.current_turn_metrics.copy()) self.current_turn_metrics.reset()
self.current_turn_metrics.reset() # append_output is called only once before tool calling
# append_output is called only once before tool calling # in non-streaming case
# in non-streaming case # so we can append all the parser messages to _messages
# so we can append all the parser messages to _messages output_msgs = self.parser.messages
output_msgs = self.parser.messages # The responses finish reason is set in the last message
# The responses finish reason is set in the last message self.finish_reason = output.outputs[0].finish_reason
self.finish_reason = output.outputs[0].finish_reason self._messages.extend(output_msgs)
else:
# Tool output. def append_tool_output(self, output: list[Message]) -> None:
output_msgs = output output_msgs = output
self._messages.extend(output_msgs) self._messages.extend(output_msgs)
def _update_prefill_token_usage(self, output: RequestOutput) -> None: def _update_prefill_token_usage(self, output: RequestOutput) -> None:
...@@ -502,45 +509,45 @@ class StreamingHarmonyContext(HarmonyContext): ...@@ -502,45 +509,45 @@ class StreamingHarmonyContext(HarmonyContext):
def messages(self) -> list: def messages(self) -> list:
return self._messages return self._messages
def append_output(self, output: RequestOutput | list[Message]) -> None: def append_output(self, output: RequestOutput) -> None:
if isinstance(output, RequestOutput): # append_output is called for each output token in streaming case,
# append_output is called for each output token in streaming case, # so we only want to add the prompt tokens once for each message.
# so we only want to add the prompt tokens once for each message. if self.first_tok_of_message:
if self.first_tok_of_message: self._update_prefill_token_usage(output)
self._update_prefill_token_usage(output) # Reset self.first_tok_of_message if needed:
# Reset self.first_tok_of_message if needed: # if the current token is the last one of the current message
# if the current token is the last one of the current message # (finished=True), then the next token processed will mark the
# (finished=True), then the next token processed will mark the # beginning of a new message
# beginning of a new message self.first_tok_of_message = output.finished
self.first_tok_of_message = output.finished for tok in output.outputs[0].token_ids:
for tok in output.outputs[0].token_ids: self.parser.process(tok)
self.parser.process(tok) self._update_decode_token_usage(output)
self._update_decode_token_usage(output)
# For streaming, update previous turn when message is complete
# For streaming, update previous turn when message is complete if output.finished:
if output.finished: self.all_turn_metrics.append(self.current_turn_metrics.copy())
self.all_turn_metrics.append(self.current_turn_metrics.copy()) self.current_turn_metrics.reset()
self.current_turn_metrics.reset() # Check if the current token is part of reasoning content
# Check if the current token is part of reasoning content self._update_num_reasoning_tokens()
self._update_num_reasoning_tokens() self.last_tok = tok
self.last_tok = tok if len(self._messages) - self.num_init_messages < len(self.parser.messages):
if len(self._messages) - self.num_init_messages < len(self.parser.messages): self._messages.extend(
self._messages.extend( self.parser.messages[len(self._messages) - self.num_init_messages :]
self.parser.messages[len(self._messages) - self.num_init_messages :] )
)
else: def append_tool_output(self, output: list[Message]) -> None:
# Handle the case of tool output in direct message format # Handle the case of tool output in direct message format
assert len(output) == 1, "Tool output should be a single message" assert len(output) == 1, "Tool output should be a single message"
msg = output[0] msg = output[0]
# Sometimes the recipient is not set for tool messages, # Sometimes the recipient is not set for tool messages,
# so we set it to "assistant" # so we set it to "assistant"
if msg.author.role == Role.TOOL and msg.recipient is None: if msg.author.role == Role.TOOL and msg.recipient is None:
msg.recipient = "assistant" msg.recipient = "assistant"
toks = self.encoding.render(msg) toks = self.encoding.render(msg)
for tok in toks: for tok in toks:
self.parser.process(tok) self.parser.process(tok)
self.last_tok = toks[-1] self.last_tok = toks[-1]
# TODO: add tool_output messages to self._messages # TODO: add tool_output messages to self._messages
def is_expecting_start(self) -> bool: def is_expecting_start(self) -> bool:
return self.parser.state == StreamState.EXPECT_START return self.parser.state == StreamState.EXPECT_START
......
...@@ -1227,7 +1227,7 @@ class OpenAIServing: ...@@ -1227,7 +1227,7 @@ class OpenAIServing:
# Call the tool and update the context with the result. # Call the tool and update the context with the result.
tool_output = await context.call_tool() tool_output = await context.call_tool()
context.append_output(tool_output) context.append_tool_output(tool_output)
# TODO: uncomment this and enable tool output streaming # TODO: uncomment this and enable tool output streaming
# yield context # yield context
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment