Unverified commit c9bf77df, authored by yzong-rh and committed by GitHub
Browse files

[BUG]: fix HF tokenizer concurrent borrow in tool parsers (#40059)


Signed-off-by: Yifan <yzong@redhat.com>
Co-authored-by: timon0305 <timon0305@outlook.com>
Co-authored-by: sfeng33 <4florafeng@gmail.com>
parent 30413442
...@@ -4,15 +4,22 @@ ...@@ -4,15 +4,22 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from transformers import AutoTokenizer
from vllm.entrypoints.openai.engine.protocol import ExtractedToolCallInformation from vllm.entrypoints.openai.engine.protocol import ExtractedToolCallInformation
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.llama_tool_parser import Llama3JsonToolParser from vllm.tool_parsers.llama_tool_parser import Llama3JsonToolParser
LLAMA_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.fixture(scope="module")
def llama_tokenizer():
return AutoTokenizer.from_pretrained(LLAMA_MODEL)
@pytest.fixture @pytest.fixture
def parser(default_tokenizer: TokenizerLike): def parser(llama_tokenizer):
return Llama3JsonToolParser(default_tokenizer) return Llama3JsonToolParser(llama_tokenizer)
def test_extract_tool_calls_simple(parser): def test_extract_tool_calls_simple(parser):
......
...@@ -34,41 +34,29 @@ class FunctionGemmaToolParser(ToolParser): ...@@ -34,41 +34,29 @@ class FunctionGemmaToolParser(ToolParser):
<start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call> <start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call>
""" """
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
# Streaming state
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = []
# FunctionGemma tokens # FunctionGemma tokens
self.tool_call_start_token: str = "<start_function_call>" tool_call_start_token: str = "<start_function_call>"
self.tool_call_end_token: str = "<end_function_call>" tool_call_end_token: str = "<end_function_call>"
# Regex patterns # Regex patterns
self.tool_call_regex = re.compile( tool_call_regex: re.Pattern = re.compile(
r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>" r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>"
r"|<start_function_call>call:(\w+)\{(.*)", r"|<start_function_call>call:(\w+)\{(.*)",
re.DOTALL, re.DOTALL,
) )
self.arg_regex = re.compile( arg_regex: re.Pattern = re.compile(
r"(\w+):<escape>(.*?)<escape>", r"(\w+):<escape>(.*?)<escape>",
re.DOTALL, re.DOTALL,
) )
if self.model_tokenizer: def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
self.tool_call_start_token_ids = self.model_tokenizer.encode( super().__init__(tokenizer, tools)
self.tool_call_start_token, add_special_tokens=False
)
self.tool_call_end_token_ids = self.model_tokenizer.encode(
self.tool_call_end_token, add_special_tokens=False
)
else:
self.tool_call_start_token_ids = []
self.tool_call_end_token_ids = []
# Streaming state
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = []
self.buffered_delta_text = "" self.buffered_delta_text = ""
def _parse_arguments(self, args_str: str) -> dict: def _parse_arguments(self, args_str: str) -> dict:
......
...@@ -45,6 +45,12 @@ class Llama3JsonToolParser(ToolParser): ...@@ -45,6 +45,12 @@ class Llama3JsonToolParser(ToolParser):
llama4_json are set. llama4_json are set.
""" """
bot_token: str = "<|python_tag|>"
# Simple regex to find opening braces - we'll use JSON decoder for parsing
# This handles arbitrary nesting depth correctly
tool_call_start_regex: re.Pattern = re.compile(r"\{")
json_decoder: json.JSONDecoder = json.JSONDecoder()
def __init__( def __init__(
self, self,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
...@@ -60,14 +66,12 @@ class Llama3JsonToolParser(ToolParser): ...@@ -60,14 +66,12 @@ class Llama3JsonToolParser(ToolParser):
self.streamed_args_for_tool: list[ self.streamed_args_for_tool: list[
str str
] = [] # map what has been streamed for each tool so far to a list ] = [] # map what has been streamed for each tool so far to a list
self.bot_token = "<|python_tag|>" self.bot_token_id = self.vocab.get(self.bot_token)
self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[ if self.bot_token_id is None:
0 raise RuntimeError(
] "Llama3JsonToolParser could not locate the bot token "
# Simple regex to find opening braces - we'll use JSON decoder for parsing f"'{self.bot_token}' in the tokenizer."
# This handles arbitrary nesting depth correctly )
self.tool_call_start_regex = re.compile(r"\{")
self.json_decoder = json.JSONDecoder()
def extract_tool_calls( def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest self, model_output: str, request: ChatCompletionRequest
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment