Unverified commit 0892d1ab, authored by Mario Hong, committed by GitHub
Browse files

[Feature]Supports Anthropic Thinking Block (#33671)


Signed-off-by: mariohong <mariohong128@gmail.com>
Co-authored-by: zetaohong <i-hongzetao@stepfun.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
parent 7600642e
...@@ -34,7 +34,7 @@ class AnthropicUsage(BaseModel): ...@@ -34,7 +34,7 @@ class AnthropicUsage(BaseModel):
class AnthropicContentBlock(BaseModel): class AnthropicContentBlock(BaseModel):
"""Content block in message""" """Content block in message"""
type: Literal["text", "image", "tool_use", "tool_result"] type: Literal["text", "image", "tool_use", "tool_result", "thinking"]
text: str | None = None text: str | None = None
# For image content # For image content
source: dict[str, Any] | None = None source: dict[str, Any] | None = None
...@@ -45,6 +45,9 @@ class AnthropicContentBlock(BaseModel): ...@@ -45,6 +45,9 @@ class AnthropicContentBlock(BaseModel):
input: dict[str, Any] | None = None input: dict[str, Any] | None = None
content: str | list[dict[str, Any]] | None = None content: str | list[dict[str, Any]] | None = None
is_error: bool | None = None is_error: bool | None = None
# For thinking content
thinking: str | None = None
signature: str | None = None
class AnthropicMessage(BaseModel): class AnthropicMessage(BaseModel):
...@@ -118,9 +121,14 @@ class AnthropicMessagesRequest(BaseModel): ...@@ -118,9 +121,14 @@ class AnthropicMessagesRequest(BaseModel):
class AnthropicDelta(BaseModel): class AnthropicDelta(BaseModel):
"""Delta for streaming responses""" """Delta for streaming responses"""
type: Literal["text_delta", "input_json_delta"] | None = None type: (
Literal["text_delta", "input_json_delta", "thinking_delta", "signature_delta"]
| None
) = None
text: str | None = None text: str | None = None
thinking: str | None = None
partial_json: str | None = None partial_json: str | None = None
signature: str | None = None
# Message delta # Message delta
stop_reason: ( stop_reason: (
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
import json import json
import logging import logging
import time import time
import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any from typing import Any
...@@ -112,6 +113,7 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -112,6 +113,7 @@ class AnthropicServingMessages(OpenAIServingChat):
# Handle complex content blocks # Handle complex content blocks
content_parts: list[dict[str, Any]] = [] content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = [] tool_calls: list[dict[str, Any]] = []
reasoning_parts: list[str] = []
for block in msg.content: for block in msg.content:
if block.type == "text" and block.text: if block.type == "text" and block.text:
...@@ -123,6 +125,8 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -123,6 +125,8 @@ class AnthropicServingMessages(OpenAIServingChat):
"image_url": {"url": block.source.get("data", "")}, "image_url": {"url": block.source.get("data", "")},
} }
) )
elif block.type == "thinking" and block.thinking is not None:
reasoning_parts.append(block.thinking)
elif block.type == "tool_use": elif block.type == "tool_use":
# Convert tool use to function call format # Convert tool use to function call format
tool_call = { tool_call = {
...@@ -157,6 +161,9 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -157,6 +161,9 @@ class AnthropicServingMessages(OpenAIServingChat):
} }
) )
if reasoning_parts:
openai_msg["reasoning"] = "".join(reasoning_parts)
# Add tool calls to the message if any # Add tool calls to the message if any
if tool_calls: if tool_calls:
openai_msg["tool_calls"] = tool_calls # type: ignore openai_msg["tool_calls"] = tool_calls # type: ignore
...@@ -167,7 +174,7 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -167,7 +174,7 @@ class AnthropicServingMessages(OpenAIServingChat):
openai_msg["content"] = content_parts[0]["text"] openai_msg["content"] = content_parts[0]["text"]
else: else:
openai_msg["content"] = content_parts # type: ignore openai_msg["content"] = content_parts # type: ignore
elif not tool_calls: elif not tool_calls and not reasoning_parts:
continue continue
openai_messages.append(openai_msg) openai_messages.append(openai_msg)
...@@ -263,23 +270,32 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -263,23 +270,32 @@ class AnthropicServingMessages(OpenAIServingChat):
output_tokens=generator.usage.completion_tokens, output_tokens=generator.usage.completion_tokens,
), ),
) )
if generator.choices[0].finish_reason == "stop": choice = generator.choices[0]
if choice.finish_reason == "stop":
result.stop_reason = "end_turn" result.stop_reason = "end_turn"
elif generator.choices[0].finish_reason == "length": elif choice.finish_reason == "length":
result.stop_reason = "max_tokens" result.stop_reason = "max_tokens"
elif generator.choices[0].finish_reason == "tool_calls": elif choice.finish_reason == "tool_calls":
result.stop_reason = "tool_use" result.stop_reason = "tool_use"
content: list[AnthropicContentBlock] = [ content: list[AnthropicContentBlock] = []
if choice.message.reasoning:
content.append(
AnthropicContentBlock(
type="thinking",
thinking=choice.message.reasoning,
signature=uuid.uuid4().hex,
)
)
if choice.message.content:
content.append(
AnthropicContentBlock( AnthropicContentBlock(
type="text", type="text",
text=generator.choices[0].message.content text=choice.message.content,
if generator.choices[0].message.content )
else "",
) )
]
for tool_call in generator.choices[0].message.tool_calls: for tool_call in choice.message.tool_calls:
anthropic_tool_call = AnthropicContentBlock( anthropic_tool_call = AnthropicContentBlock(
type="tool_use", type="tool_use",
id=tool_call.id, id=tool_call.id,
...@@ -297,10 +313,85 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -297,10 +313,85 @@ class AnthropicServingMessages(OpenAIServingChat):
generator: AsyncGenerator[str, None], generator: AsyncGenerator[str, None],
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
try: try:
class _ActiveBlockState:
    """Bookkeeping for the content block currently open in the stream.

    ``content_block_index`` is the index the *next* opened block will
    receive. It is deliberately left untouched by ``reset()``; the code
    that closes a block is responsible for advancing it. All other
    attributes describe the block that is open right now and are cleared
    by ``reset()``.
    """

    def __init__(self) -> None:
        # Running block counter; survives reset().
        self.content_block_index = 0
        # Per-block fields; None / False means "no block is open".
        self.block_type: str | None = None
        self.block_index: int | None = None
        self.block_signature: str | None = None
        self.signature_emitted: bool = False
        self.tool_use_id: str | None = None

    def reset(self) -> None:
        """Forget the active block without touching the running index."""
        self.block_type = None
        self.block_index = None
        self.block_signature = None
        self.signature_emitted = False
        self.tool_use_id = None

    def start(self, block: "AnthropicContentBlock") -> None:
        """Record ``block`` as the active block at the current index.

        Only ``block.type`` and ``block.id`` are read. A "thinking" block
        gets a freshly generated signature that is still pending
        emission; every other block type has no signature and is marked
        as having nothing left to emit. ``tool_use_id`` is set only for
        "tool_use" blocks.
        """
        is_thinking = block.type == "thinking"
        self.block_type = block.type
        self.block_index = self.content_block_index
        # Random hex token used as the signature for thinking blocks.
        self.block_signature = uuid.uuid4().hex if is_thinking else None
        self.signature_emitted = not is_thinking
        self.tool_use_id = block.id if block.type == "tool_use" else None
first_item = True first_item = True
finish_reason = None finish_reason = None
content_block_index = 0 state = _ActiveBlockState()
content_block_started = False # Map from tool call index to tool_use_id
tool_index_to_id: dict[int, str] = {}
def stop_active_block():
    """Close the block tracked by ``state`` and return the SSE events.

    Returns an empty list when no block is open. Otherwise the list ends
    with a ``content_block_stop`` event, preceded by a
    ``signature_delta`` event when the open block is a thinking block
    whose signature has not been sent yet. Afterwards the per-block
    state is cleared and the running block index is advanced by one.
    """
    if state.block_type is None:
        return []

    closing_events: list[str] = []

    signature_pending = (
        state.block_type == "thinking"
        and state.block_signature is not None
        and not state.signature_emitted
    )
    if signature_pending:
        signature_event = AnthropicStreamEvent(
            index=state.block_index,
            type="content_block_delta",
            delta=AnthropicDelta(
                type="signature_delta",
                signature=state.block_signature,
            ),
        )
        closing_events.append(
            wrap_data_with_event(
                signature_event.model_dump_json(exclude_unset=True),
                "content_block_delta",
            )
        )
        state.signature_emitted = True

    stop_event = AnthropicStreamEvent(
        index=state.block_index,
        type="content_block_stop",
    )
    closing_events.append(
        wrap_data_with_event(
            stop_event.model_dump_json(exclude_unset=True),
            "content_block_stop",
        )
    )

    # Clear the per-block fields first, then move on to the next index.
    state.reset()
    state.content_block_index += 1
    return closing_events
def start_block(block: AnthropicContentBlock):
    """Open ``block`` at the current index and return its SSE start event.

    The ``content_block_start`` event is serialized before ``state`` is
    updated, so it carries ``state.content_block_index`` as it was on
    entry. ``state.start(block)`` then records the block as active.
    """
    start_event = AnthropicStreamEvent(
        index=state.content_block_index,
        type="content_block_start",
        content_block=block,
    )
    serialized = wrap_data_with_event(
        start_event.model_dump_json(exclude_unset=True),
        "content_block_start",
    )
    state.start(block)
    return serialized
async for item in generator: async for item in generator:
if item.startswith("data:"): if item.startswith("data:"):
...@@ -326,6 +417,8 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -326,6 +417,8 @@ class AnthropicServingMessages(OpenAIServingChat):
id=origin_chunk.id, id=origin_chunk.id,
content=[], content=[],
model=origin_chunk.model, model=origin_chunk.model,
stop_reason=None,
stop_sequence=None,
usage=AnthropicUsage( usage=AnthropicUsage(
input_tokens=origin_chunk.usage.prompt_tokens input_tokens=origin_chunk.usage.prompt_tokens
if origin_chunk.usage if origin_chunk.usage
...@@ -341,13 +434,8 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -341,13 +434,8 @@ class AnthropicServingMessages(OpenAIServingChat):
# last chunk including usage info # last chunk including usage info
if len(origin_chunk.choices) == 0: if len(origin_chunk.choices) == 0:
if content_block_started: for event in stop_active_block():
stop_chunk = AnthropicStreamEvent( yield event
index=content_block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_stop")
stop_reason = self.stop_reason_map.get( stop_reason = self.stop_reason_map.get(
finish_reason or "stop" finish_reason or "stop"
) )
...@@ -369,26 +457,55 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -369,26 +457,55 @@ class AnthropicServingMessages(OpenAIServingChat):
if origin_chunk.choices[0].finish_reason is not None: if origin_chunk.choices[0].finish_reason is not None:
finish_reason = origin_chunk.choices[0].finish_reason finish_reason = origin_chunk.choices[0].finish_reason
continue # continue
# content # thinking / text content
if origin_chunk.choices[0].delta.content is not None: reasoning_delta = origin_chunk.choices[0].delta.reasoning
if not content_block_started: if reasoning_delta is not None:
if reasoning_delta == "":
pass
else:
if state.block_type != "thinking":
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(
type="thinking", thinking=""
)
)
yield start_event
chunk = AnthropicStreamEvent( chunk = AnthropicStreamEvent(
index=content_block_index, index=(
type="content_block_start", state.block_index
content_block=AnthropicContentBlock( if state.block_index is not None
type="text", text="" else state.content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="thinking_delta",
thinking=reasoning_delta,
), ),
) )
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_start") yield wrap_data_with_event(data, "content_block_delta")
content_block_started = True
if origin_chunk.choices[0].delta.content is not None:
if origin_chunk.choices[0].delta.content == "": if origin_chunk.choices[0].delta.content == "":
continue pass
else:
if state.block_type != "text":
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(type="text", text="")
)
yield start_event
chunk = AnthropicStreamEvent( chunk = AnthropicStreamEvent(
index=content_block_index, index=(
state.block_index
if state.block_index is not None
else state.content_block_index
),
type="content_block_delta", type="content_block_delta",
delta=AnthropicDelta( delta=AnthropicDelta(
type="text_delta", type="text_delta",
...@@ -397,44 +514,47 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -397,44 +514,47 @@ class AnthropicServingMessages(OpenAIServingChat):
) )
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta") yield wrap_data_with_event(data, "content_block_delta")
continue
# tool calls # tool calls - process all tool calls in the delta
elif len(origin_chunk.choices[0].delta.tool_calls) > 0: if len(origin_chunk.choices[0].delta.tool_calls) > 0:
tool_call = origin_chunk.choices[0].delta.tool_calls[0] for tool_call in origin_chunk.choices[0].delta.tool_calls:
if tool_call.id is not None: if tool_call.id is not None:
if content_block_started: # Update mapping for incremental updates
stop_chunk = AnthropicStreamEvent( tool_index_to_id[tool_call.index] = tool_call.id
index=content_block_index, # Only create new block if different tool call
type="content_block_stop", # AND has a name
) tool_name = (
data = stop_chunk.model_dump_json( tool_call.function.name
exclude_unset=True if tool_call.function
) else None
yield wrap_data_with_event(
data, "content_block_stop"
) )
content_block_started = False if (
content_block_index += 1 state.tool_use_id != tool_call.id
and tool_name is not None
chunk = AnthropicStreamEvent( ):
index=content_block_index, for event in stop_active_block():
type="content_block_start", yield event
content_block=AnthropicContentBlock( start_event = start_block(
AnthropicContentBlock(
type="tool_use", type="tool_use",
id=tool_call.id, id=tool_call.id,
name=tool_call.function.name name=tool_name,
if tool_call.function
else None,
input={}, input={},
),
) )
data = chunk.model_dump_json(exclude_unset=True) )
yield wrap_data_with_event(data, "content_block_start") yield start_event
content_block_started = True # Handle initial arguments if present
if tool_call.function and tool_call.function.arguments: if (
tool_call.function
and tool_call.function.arguments
and state.tool_use_id == tool_call.id
):
chunk = AnthropicStreamEvent( chunk = AnthropicStreamEvent(
index=content_block_index, index=(
state.block_index
if state.block_index is not None
else state.content_block_index
),
type="content_block_delta", type="content_block_delta",
delta=AnthropicDelta( delta=AnthropicDelta(
type="input_json_delta", type="input_json_delta",
...@@ -445,20 +565,31 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -445,20 +565,31 @@ class AnthropicServingMessages(OpenAIServingChat):
yield wrap_data_with_event( yield wrap_data_with_event(
data, "content_block_delta" data, "content_block_delta"
) )
else: else:
# Incremental update - use index to find tool_use_id
tool_use_id = tool_index_to_id.get(tool_call.index)
if (
tool_use_id is not None
and tool_call.function
and tool_call.function.arguments
and state.tool_use_id == tool_use_id
):
chunk = AnthropicStreamEvent( chunk = AnthropicStreamEvent(
index=content_block_index, index=(
state.block_index
if state.block_index is not None
else state.content_block_index
),
type="content_block_delta", type="content_block_delta",
delta=AnthropicDelta( delta=AnthropicDelta(
type="input_json_delta", type="input_json_delta",
partial_json=tool_call.function.arguments partial_json=tool_call.function.arguments,
if tool_call.function
else None,
), ),
) )
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta") yield wrap_data_with_event(
data, "content_block_delta"
)
continue continue
else: else:
error_response = AnthropicStreamEvent( error_response = AnthropicStreamEvent(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment