Unverified Commit 9a87b057 authored by Chauncey's avatar Chauncey Committed by GitHub
Browse files

[Feat] Supports Anthropic Messages count_tokens API (#35588)


Signed-off-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
parent 510bc9e1
...@@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, FastAPI, Request ...@@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.anthropic.protocol import ( from vllm.entrypoints.anthropic.protocol import (
AnthropicCountTokensRequest,
AnthropicCountTokensResponse,
AnthropicError, AnthropicError,
AnthropicErrorResponse, AnthropicErrorResponse,
AnthropicMessagesRequest, AnthropicMessagesRequest,
...@@ -31,6 +33,18 @@ def messages(request: Request) -> AnthropicServingMessages: ...@@ -31,6 +33,18 @@ def messages(request: Request) -> AnthropicServingMessages:
return request.app.state.anthropic_serving_messages return request.app.state.anthropic_serving_messages
def translate_error_response(response: ErrorResponse) -> JSONResponse:
anthropic_error = AnthropicErrorResponse(
error=AnthropicError(
type=response.error.type,
message=response.error.message,
)
)
return JSONResponse(
status_code=response.error.code, content=anthropic_error.model_dump()
)
@router.post( @router.post(
"/v1/messages", "/v1/messages",
dependencies=[Depends(validate_json_request)], dependencies=[Depends(validate_json_request)],
...@@ -44,17 +58,6 @@ def messages(request: Request) -> AnthropicServingMessages: ...@@ -44,17 +58,6 @@ def messages(request: Request) -> AnthropicServingMessages:
@with_cancellation @with_cancellation
@load_aware_call @load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request): async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
def translate_error_response(response: ErrorResponse) -> JSONResponse:
anthropic_error = AnthropicErrorResponse(
error=AnthropicError(
type=response.error.type,
message=response.error.message,
)
)
return JSONResponse(
status_code=response.error.code, content=anthropic_error.model_dump()
)
handler = messages(raw_request) handler = messages(raw_request)
if handler is None: if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization base_server = raw_request.app.state.openai_serving_tokenization
...@@ -88,5 +91,46 @@ async def create_messages(request: AnthropicMessagesRequest, raw_request: Reques ...@@ -88,5 +91,46 @@ async def create_messages(request: AnthropicMessagesRequest, raw_request: Reques
return StreamingResponse(content=generator, media_type="text/event-stream") return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post(
"/v1/messages/count_tokens",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"model": AnthropicCountTokensResponse},
HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
},
)
@load_aware_call
@with_cancellation
async def count_tokens(request: AnthropicCountTokensRequest, raw_request: Request):
handler = messages(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
error = base_server.create_error_response(
message="The model does not support Messages API"
)
return translate_error_response(error)
try:
response = await handler.count_tokens(request, raw_request)
except Exception as e:
logger.exception("Error in count_tokens: %s", e)
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
content=AnthropicErrorResponse(
error=AnthropicError(
type="internal_error",
message=str(e),
)
).model_dump(),
)
if isinstance(response, ErrorResponse):
return translate_error_response(response)
return JSONResponse(content=response.model_dump(exclude_none=True))
def attach_router(app: FastAPI): def attach_router(app: FastAPI):
app.include_router(router) app.include_router(router)
...@@ -175,3 +175,33 @@ class AnthropicMessagesResponse(BaseModel): ...@@ -175,3 +175,33 @@ class AnthropicMessagesResponse(BaseModel):
def model_post_init(self, __context): def model_post_init(self, __context):
if not self.id: if not self.id:
self.id = f"msg_{int(time.time() * 1000)}" self.id = f"msg_{int(time.time() * 1000)}"
class AnthropicContextManagement(BaseModel):
"""Context management information for token counting."""
original_input_tokens: int
class AnthropicCountTokensRequest(BaseModel):
"""Anthropic messages.count_tokens request"""
model: str
messages: list[AnthropicMessage]
system: str | list[AnthropicContentBlock] | None = None
tool_choice: AnthropicToolChoice | None = None
tools: list[AnthropicTool] | None = None
@field_validator("model")
@classmethod
def validate_model(cls, v):
if not v:
raise ValueError("Model is required")
return v
class AnthropicCountTokensResponse(BaseModel):
"""Anthropic messages.count_tokens response"""
input_tokens: int
context_management: AnthropicContextManagement | None = None
...@@ -17,6 +17,9 @@ from fastapi import Request ...@@ -17,6 +17,9 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import ( from vllm.entrypoints.anthropic.protocol import (
AnthropicContentBlock, AnthropicContentBlock,
AnthropicContextManagement,
AnthropicCountTokensRequest,
AnthropicCountTokensResponse,
AnthropicDelta, AnthropicDelta,
AnthropicError, AnthropicError,
AnthropicMessagesRequest, AnthropicMessagesRequest,
...@@ -109,135 +112,202 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -109,135 +112,202 @@ class AnthropicServingMessages(OpenAIServingChat):
@classmethod @classmethod
def _convert_anthropic_to_openai_request( def _convert_anthropic_to_openai_request(
cls, anthropic_request: AnthropicMessagesRequest cls, anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest
) -> ChatCompletionRequest: ) -> ChatCompletionRequest:
"""Convert Anthropic message format to OpenAI format""" """Convert Anthropic message format to OpenAI format"""
openai_messages = [] openai_messages: list[dict[str, Any]] = []
cls._convert_system_message(anthropic_request, openai_messages)
cls._convert_messages(anthropic_request.messages, openai_messages)
req = cls._build_base_request(anthropic_request, openai_messages)
cls._handle_streaming_options(req, anthropic_request)
cls._convert_tool_choice(anthropic_request, req)
cls._convert_tools(anthropic_request, req)
return req
# Add system message if provided @classmethod
if anthropic_request.system: def _convert_system_message(
if isinstance(anthropic_request.system, str): cls,
openai_messages.append( anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
{"role": "system", "content": anthropic_request.system} openai_messages: list[dict[str, Any]],
) ) -> None:
else: """Convert Anthropic system message to OpenAI format"""
system_prompt = "" if not anthropic_request.system:
for block in anthropic_request.system: return
if block.type == "text" and block.text:
system_prompt += block.text if isinstance(anthropic_request.system, str):
openai_messages.append({"role": "system", "content": system_prompt}) openai_messages.append(
{"role": "system", "content": anthropic_request.system}
)
else:
system_prompt = ""
for block in anthropic_request.system:
if block.type == "text" and block.text:
system_prompt += block.text
openai_messages.append({"role": "system", "content": system_prompt})
for msg in anthropic_request.messages: @classmethod
def _convert_messages(
cls, messages: list, openai_messages: list[dict[str, Any]]
) -> None:
"""Convert Anthropic messages to OpenAI format"""
for msg in messages:
openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore
if isinstance(msg.content, str): if isinstance(msg.content, str):
openai_msg["content"] = msg.content openai_msg["content"] = msg.content
else: else:
# Handle complex content blocks cls._convert_message_content(msg, openai_msg, openai_messages)
content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = [] openai_messages.append(openai_msg)
reasoning_parts: list[str] = []
for block in msg.content:
if block.type == "text" and block.text:
content_parts.append({"type": "text", "text": block.text})
elif block.type == "image" and block.source:
image_url = cls._convert_image_source_to_url(block.source)
content_parts.append(
{
"type": "image_url",
"image_url": {"url": image_url},
}
)
elif block.type == "thinking" and block.thinking is not None:
reasoning_parts.append(block.thinking)
elif block.type == "tool_use":
# Convert tool use to function call format
tool_call = {
"id": block.id or f"call_{int(time.time())}",
"type": "function",
"function": {
"name": block.name or "",
"arguments": json.dumps(block.input or {}),
},
}
tool_calls.append(tool_call)
elif block.type == "tool_result":
if msg.role == "user":
# Parse tool_result content which can be
# a string or a list of content blocks
# (text, image, etc.)
tool_text = ""
tool_image_urls: list[str] = []
if isinstance(block.content, str):
tool_text = block.content
elif isinstance(block.content, list):
text_parts: list[str] = []
for item in block.content:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type == "text":
text_parts.append(item.get("text", ""))
elif item_type == "image":
source = item.get("source", {})
url = cls._convert_image_source_to_url(source)
if url:
tool_image_urls.append(url)
tool_text = "\n".join(text_parts)
openai_messages.append(
{
"role": "tool",
"tool_call_id": block.tool_use_id or "",
"content": tool_text or "",
}
)
# OpenAI tool messages only support string
# content, so inject images from tool
# results as a follow-up user message
if tool_image_urls:
openai_messages.append(
{
"role": "user",
"content": [ # type: ignore[dict-item]
{
"type": "image_url",
"image_url": {"url": img},
}
for img in tool_image_urls
],
}
)
else:
# Assistant tool result becomes regular text
tool_result_text = (
str(block.content) if block.content else ""
)
content_parts.append(
{
"type": "text",
"text": f"Tool result: {tool_result_text}",
}
)
if reasoning_parts: @classmethod
openai_msg["reasoning"] = "".join(reasoning_parts) def _convert_message_content(
cls,
msg,
openai_msg: dict[str, Any],
openai_messages: list[dict[str, Any]],
) -> None:
"""Convert complex message content blocks"""
content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []
reasoning_parts: list[str] = []
for block in msg.content:
cls._convert_block(
block,
msg.role,
content_parts,
tool_calls,
reasoning_parts,
openai_messages,
)
# Add tool calls to the message if any if reasoning_parts:
if tool_calls: openai_msg["reasoning"] = "".join(reasoning_parts)
openai_msg["tool_calls"] = tool_calls # type: ignore
# Add content parts if any if tool_calls:
if content_parts: openai_msg["tool_calls"] = tool_calls # type: ignore
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
openai_msg["content"] = content_parts[0]["text"] if content_parts:
else: if len(content_parts) == 1 and content_parts[0]["type"] == "text":
openai_msg["content"] = content_parts # type: ignore openai_msg["content"] = content_parts[0]["text"]
elif not tool_calls and not reasoning_parts: else:
openai_msg["content"] = content_parts # type: ignore
elif not tool_calls and not reasoning_parts:
return
@classmethod
def _convert_block(
cls,
block,
role: str,
content_parts: list[dict[str, Any]],
tool_calls: list[dict[str, Any]],
reasoning_parts: list[str],
openai_messages: list[dict[str, Any]],
) -> None:
"""Convert individual content block"""
if block.type == "text" and block.text:
content_parts.append({"type": "text", "text": block.text})
elif block.type == "image" and block.source:
image_url = cls._convert_image_source_to_url(block.source)
content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
elif block.type == "thinking" and block.thinking is not None:
reasoning_parts.append(block.thinking)
elif block.type == "tool_use":
cls._convert_tool_use_block(block, tool_calls)
elif block.type == "tool_result":
cls._convert_tool_result_block(block, role, openai_messages, content_parts)
@classmethod
def _convert_tool_use_block(cls, block, tool_calls: list[dict[str, Any]]) -> None:
"""Convert tool_use block to OpenAI function call format"""
tool_call = {
"id": block.id or f"call_{int(time.time())}",
"type": "function",
"function": {
"name": block.name or "",
"arguments": json.dumps(block.input or {}),
},
}
tool_calls.append(tool_call)
@classmethod
def _convert_tool_result_block(
cls,
block,
role: str,
openai_messages: list[dict[str, Any]],
content_parts: list[dict[str, Any]],
) -> None:
"""Convert tool_result block to OpenAI format"""
if role == "user":
cls._convert_user_tool_result(block, openai_messages)
else:
tool_result_text = str(block.content) if block.content else ""
content_parts.append(
{"type": "text", "text": f"Tool result: {tool_result_text}"}
)
@classmethod
def _convert_user_tool_result(
cls, block, openai_messages: list[dict[str, Any]]
) -> None:
"""Convert user tool_result with text and image support"""
tool_text = ""
tool_image_urls: list[str] = []
if isinstance(block.content, str):
tool_text = block.content
elif isinstance(block.content, list):
text_parts: list[str] = []
for item in block.content:
if not isinstance(item, dict):
continue continue
item_type = item.get("type")
if item_type == "text":
text_parts.append(item.get("text", ""))
elif item_type == "image":
source = item.get("source", {})
url = cls._convert_image_source_to_url(source)
if url:
tool_image_urls.append(url)
tool_text = "\n".join(text_parts)
openai_messages.append(
{
"role": "tool",
"tool_call_id": block.tool_use_id or "",
"content": tool_text or "",
}
)
openai_messages.append(openai_msg) if tool_image_urls:
openai_messages.append(
{
"role": "user",
"content": [ # type: ignore[dict-item]
{"type": "image_url", "image_url": {"url": img}}
for img in tool_image_urls
],
}
)
req = ChatCompletionRequest( @classmethod
def _build_base_request(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
openai_messages: list[dict[str, Any]],
) -> ChatCompletionRequest:
"""Build base ChatCompletionRequest"""
if isinstance(anthropic_request, AnthropicCountTokensRequest):
return ChatCompletionRequest(
model=anthropic_request.model,
messages=openai_messages,
)
return ChatCompletionRequest(
model=anthropic_request.model, model=anthropic_request.model,
messages=openai_messages, messages=openai_messages,
max_tokens=anthropic_request.max_tokens, max_tokens=anthropic_request.max_tokens,
...@@ -248,19 +318,38 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -248,19 +318,38 @@ class AnthropicServingMessages(OpenAIServingChat):
top_k=anthropic_request.top_k, top_k=anthropic_request.top_k,
) )
@classmethod
def _handle_streaming_options(
cls,
req: ChatCompletionRequest,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
) -> None:
"""Handle streaming configuration"""
if isinstance(anthropic_request, AnthropicCountTokensRequest):
return
if anthropic_request.stream: if anthropic_request.stream:
req.stream = anthropic_request.stream req.stream = anthropic_request.stream
req.stream_options = StreamOptions.validate( req.stream_options = StreamOptions.model_validate(
{"include_usage": True, "continuous_usage_stats": True} {"include_usage": True, "continuous_usage_stats": True}
) )
@classmethod
def _convert_tool_choice(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
req: ChatCompletionRequest,
) -> None:
"""Convert Anthropic tool_choice to OpenAI format"""
if anthropic_request.tool_choice is None: if anthropic_request.tool_choice is None:
req.tool_choice = None req.tool_choice = None
elif anthropic_request.tool_choice.type == "auto": return
tool_choice_type = anthropic_request.tool_choice.type
if tool_choice_type == "auto":
req.tool_choice = "auto" req.tool_choice = "auto"
elif anthropic_request.tool_choice.type == "any": elif tool_choice_type == "any":
req.tool_choice = "required" req.tool_choice = "required"
elif anthropic_request.tool_choice.type == "tool": elif tool_choice_type == "tool":
req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate( req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate(
{ {
"type": "function", "type": "function",
...@@ -268,9 +357,17 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -268,9 +357,17 @@ class AnthropicServingMessages(OpenAIServingChat):
} }
) )
tools = [] @classmethod
def _convert_tools(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
req: ChatCompletionRequest,
) -> None:
"""Convert Anthropic tools to OpenAI format"""
if anthropic_request.tools is None: if anthropic_request.tools is None:
return req return
tools = []
for tool in anthropic_request.tools: for tool in anthropic_request.tools:
tools.append( tools.append(
ChatCompletionToolsParam.model_validate( ChatCompletionToolsParam.model_validate(
...@@ -284,10 +381,10 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -284,10 +381,10 @@ class AnthropicServingMessages(OpenAIServingChat):
} }
) )
) )
if req.tool_choice is None: if req.tool_choice is None:
req.tool_choice = "auto" req.tool_choice = "auto"
req.tools = tools req.tools = tools
return req
async def create_messages( async def create_messages(
self, self,
...@@ -670,3 +767,31 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -670,3 +767,31 @@ class AnthropicServingMessages(OpenAIServingChat):
data = error_response.model_dump_json(exclude_unset=True) data = error_response.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "error") yield wrap_data_with_event(data, "error")
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
async def count_tokens(
self,
request: AnthropicCountTokensRequest,
raw_request: Request | None = None,
) -> AnthropicCountTokensResponse | ErrorResponse:
"""Implements Anthropic's messages.count_tokens endpoint."""
chat_req = self._convert_anthropic_to_openai_request(request)
result = await self.render_chat_request(chat_req)
if isinstance(result, ErrorResponse):
return result
_, engine_prompts = result
input_tokens = sum( # type: ignore
len(prompt["prompt_token_ids"]) # type: ignore[typeddict-item, misc]
for prompt in engine_prompts
if "prompt_token_ids" in prompt
)
response = AnthropicCountTokensResponse(
input_tokens=input_tokens,
context_management=AnthropicContextManagement(
original_input_tokens=input_tokens
),
)
return response
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment