Unverified Commit 293f036e authored by Viacheslav's avatar Viacheslav Committed by GitHub
Browse files

Add gigachat 3.1 tool parser + fix gigachat3 tool parser (#36664)


Signed-off-by: default avatarViacheslav Barinov <viacheslav.teh@gmail.com>
parent 0fb142a4
...@@ -13,6 +13,13 @@ from vllm.entrypoints.openai.engine.protocol import FunctionCall ...@@ -13,6 +13,13 @@ from vllm.entrypoints.openai.engine.protocol import FunctionCall
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager from vllm.tool_parsers import ToolParser, ToolParserManager
MSG_SEP_TOKEN = "<|message_sep|>\n\n"
ROLE_SEP_TOKEN = "<|role_sep|>\n"
EOS_TOKEN = "</s>"
TOOL_HEADER_GIGACHAT3 = f"function call{ROLE_SEP_TOKEN}"
TOOL_HEADER_GIGACHAT31 = "<|function_call|>"
SIMPLE_ARGS_DICT = { SIMPLE_ARGS_DICT = {
"action": "create", "action": "create",
"id": "preferences", "id": "preferences",
...@@ -24,7 +31,10 @@ SIMPLE_FUNCTION_JSON = json.dumps( ...@@ -24,7 +31,10 @@ SIMPLE_FUNCTION_JSON = json.dumps(
}, },
ensure_ascii=False, ensure_ascii=False,
) )
SIMPLE_FUNCTION_OUTPUT = "function call" + SIMPLE_FUNCTION_JSON SIMPLE_FUNCTION_OUTPUT_GIGACHAT3 = (
f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{SIMPLE_FUNCTION_JSON}"
)
SIMPLE_FUNCTION_OUTPUT_GIGACHAT31 = f"{TOOL_HEADER_GIGACHAT31}{SIMPLE_FUNCTION_JSON}"
SIMPLE_FUNCTION_CALL = FunctionCall( SIMPLE_FUNCTION_CALL = FunctionCall(
name="manage_user_memory", name="manage_user_memory",
arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False), arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False),
...@@ -38,7 +48,12 @@ PARAMETERLESS_FUNCTION_JSON = json.dumps( ...@@ -38,7 +48,12 @@ PARAMETERLESS_FUNCTION_JSON = json.dumps(
}, },
ensure_ascii=False, ensure_ascii=False,
) )
PARAMETERLESS_FUNCTION_OUTPUT = "function call" + PARAMETERLESS_FUNCTION_JSON PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3 = (
f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{PARAMETERLESS_FUNCTION_JSON}"
)
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31 = (
f"{TOOL_HEADER_GIGACHAT31}{PARAMETERLESS_FUNCTION_JSON}"
)
PARAMETERLESS_FUNCTION_CALL = FunctionCall( PARAMETERLESS_FUNCTION_CALL = FunctionCall(
name="manage_user_memory", name="manage_user_memory",
arguments=json.dumps({}, ensure_ascii=False), arguments=json.dumps({}, ensure_ascii=False),
...@@ -62,17 +77,38 @@ COMPLEX_FUNCTION_JSON = json.dumps( ...@@ -62,17 +77,38 @@ COMPLEX_FUNCTION_JSON = json.dumps(
}, },
ensure_ascii=False, ensure_ascii=False,
) )
COMPLEX_FUNCTION_OUTPUT = "function call" + COMPLEX_FUNCTION_JSON COMPLEX_FUNCTION_OUTPUT_GIGACHAT3 = (
f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{COMPLEX_FUNCTION_JSON}"
)
COMPLEX_FUNCTION_OUTPUT_GIGACHAT31 = f"{TOOL_HEADER_GIGACHAT31}{COMPLEX_FUNCTION_JSON}"
COMPLEX_FUNCTION_CALL = FunctionCall( COMPLEX_FUNCTION_CALL = FunctionCall(
name="manage_user_memory", name="manage_user_memory",
arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False), arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False),
) )
CONTENT_TEXT = "I'll check that for you."
MIXED_OUTPUT_GIGACHAT3 = f"{CONTENT_TEXT}{SIMPLE_FUNCTION_OUTPUT_GIGACHAT3}"
MIXED_OUTPUT_GIGACHAT31 = f"{CONTENT_TEXT}{SIMPLE_FUNCTION_OUTPUT_GIGACHAT31}"
@pytest.fixture(name="gigachat_tokenizer")
def fixture_gigachat_tokenizer(default_tokenizer: TokenizerLike):
default_tokenizer.add_tokens(
[
MSG_SEP_TOKEN,
ROLE_SEP_TOKEN,
TOOL_HEADER_GIGACHAT31,
EOS_TOKEN,
]
)
return default_tokenizer
@pytest.mark.parametrize("streaming", [True, False]) @pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike): def test_no_tool_call(streaming: bool, gigachat_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
default_tokenizer gigachat_tokenizer
) )
model_output = "How can I help you today?" model_output = "How can I help you today?"
content, tool_calls = run_tool_extraction( content, tool_calls = run_tool_extraction(
...@@ -85,45 +121,143 @@ def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike): ...@@ -85,45 +121,143 @@ def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
TEST_CASES = [ TEST_CASES = [
pytest.param( pytest.param(
True, True,
SIMPLE_FUNCTION_OUTPUT, SIMPLE_FUNCTION_OUTPUT_GIGACHAT3,
[SIMPLE_FUNCTION_CALL], [SIMPLE_FUNCTION_CALL],
None, None,
id="simple_streaming", id="simple_streaming_gigachat3",
), ),
pytest.param( pytest.param(
False, False,
SIMPLE_FUNCTION_OUTPUT, SIMPLE_FUNCTION_OUTPUT_GIGACHAT3,
[SIMPLE_FUNCTION_CALL], [SIMPLE_FUNCTION_CALL],
None, None,
id="simple_nonstreaming", id="simple_nonstreaming_gigachat3",
), ),
pytest.param( pytest.param(
True, True,
PARAMETERLESS_FUNCTION_OUTPUT, PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3,
[PARAMETERLESS_FUNCTION_CALL], [PARAMETERLESS_FUNCTION_CALL],
None, None,
id="parameterless_streaming", id="parameterless_streaming_gigachat3",
), ),
pytest.param( pytest.param(
False, False,
PARAMETERLESS_FUNCTION_OUTPUT, PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3,
[PARAMETERLESS_FUNCTION_CALL], [PARAMETERLESS_FUNCTION_CALL],
None, None,
id="parameterless_nonstreaming", id="parameterless_nonstreaming_gigachat3",
), ),
pytest.param( pytest.param(
True, True,
COMPLEX_FUNCTION_OUTPUT, COMPLEX_FUNCTION_OUTPUT_GIGACHAT3,
[COMPLEX_FUNCTION_CALL], [COMPLEX_FUNCTION_CALL],
None, None,
id="complex_streaming", id="complex_streaming_gigachat3",
), ),
pytest.param( pytest.param(
False, False,
COMPLEX_FUNCTION_OUTPUT, COMPLEX_FUNCTION_OUTPUT_GIGACHAT3,
[COMPLEX_FUNCTION_CALL], [COMPLEX_FUNCTION_CALL],
None, None,
id="complex_nonstreaming", id="complex_nonstreaming_gigachat3",
),
pytest.param(
True,
MIXED_OUTPUT_GIGACHAT3,
[SIMPLE_FUNCTION_CALL],
CONTENT_TEXT,
id="mixed_content_streaming_gigachat3",
),
pytest.param(
False,
MIXED_OUTPUT_GIGACHAT3,
[SIMPLE_FUNCTION_CALL],
CONTENT_TEXT,
id="mixed_content_nonstreaming_gigachat3",
),
pytest.param(
True,
MIXED_OUTPUT_GIGACHAT3 + EOS_TOKEN,
[SIMPLE_FUNCTION_CALL],
CONTENT_TEXT,
id="mixed_content_streaming_with_eos_gigachat3",
),
pytest.param(
False,
MIXED_OUTPUT_GIGACHAT3 + EOS_TOKEN,
[SIMPLE_FUNCTION_CALL],
CONTENT_TEXT,
id="mixed_content_nonstreaming_with_eos_gigachat3",
),
pytest.param(
True,
SIMPLE_FUNCTION_OUTPUT_GIGACHAT31,
[SIMPLE_FUNCTION_CALL],
None,
id="simple_streaming_gigachat31",
),
pytest.param(
False,
SIMPLE_FUNCTION_OUTPUT_GIGACHAT31,
[SIMPLE_FUNCTION_CALL],
None,
id="simple_nonstreaming_gigachat31",
),
pytest.param(
True,
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31,
[PARAMETERLESS_FUNCTION_CALL],
None,
id="parameterless_streaming_gigachat31",
),
pytest.param(
False,
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31,
[PARAMETERLESS_FUNCTION_CALL],
None,
id="parameterless_nonstreaming_gigachat31",
),
pytest.param(
True,
COMPLEX_FUNCTION_OUTPUT_GIGACHAT31,
[COMPLEX_FUNCTION_CALL],
None,
id="complex_streaming_gigachat31",
),
pytest.param(
False,
COMPLEX_FUNCTION_OUTPUT_GIGACHAT31,
[COMPLEX_FUNCTION_CALL],
None,
id="complex_nonstreaming_gigachat31",
),
pytest.param(
True,
MIXED_OUTPUT_GIGACHAT31,
[SIMPLE_FUNCTION_CALL],
CONTENT_TEXT,
id="mixed_content_streaming_gigachat31",
),
pytest.param(
False,
MIXED_OUTPUT_GIGACHAT31,
[SIMPLE_FUNCTION_CALL],
CONTENT_TEXT,
id="mixed_content_nonstreaming_gigachat31",
),
pytest.param(
True,
MIXED_OUTPUT_GIGACHAT31 + EOS_TOKEN,
[SIMPLE_FUNCTION_CALL],
CONTENT_TEXT,
id="mixed_content_streaming_with_eos_gigachat31",
),
pytest.param(
False,
MIXED_OUTPUT_GIGACHAT31 + EOS_TOKEN,
[SIMPLE_FUNCTION_CALL],
CONTENT_TEXT,
id="mixed_content_nonstreaming_with_eos_gigachat31",
), ),
] ]
...@@ -136,14 +270,16 @@ def test_tool_call( ...@@ -136,14 +270,16 @@ def test_tool_call(
model_output: str, model_output: str,
expected_tool_calls: list[FunctionCall], expected_tool_calls: list[FunctionCall],
expected_content: str | None, expected_content: str | None,
default_tokenizer: TokenizerLike, gigachat_tokenizer: TokenizerLike,
): ):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
default_tokenizer gigachat_tokenizer
) )
content, tool_calls = run_tool_extraction( content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming tool_parser, model_output, streaming=streaming
) )
if content == "":
content = None
assert content == expected_content assert content == expected_content
assert len(tool_calls) == len(expected_tool_calls) assert len(tool_calls) == len(expected_tool_calls)
for actual, expected in zip(tool_calls, expected_tool_calls): for actual, expected in zip(tool_calls, expected_tool_calls):
...@@ -154,15 +290,46 @@ def test_tool_call( ...@@ -154,15 +290,46 @@ def test_tool_call(
assert actual_args == expected_args assert actual_args == expected_args
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike): @pytest.mark.parametrize(
"model_output_deltas",
[
pytest.param(
[
CONTENT_TEXT[:3],
CONTENT_TEXT[3:5],
CONTENT_TEXT[5:],
MSG_SEP_TOKEN,
TOOL_HEADER_GIGACHAT3,
COMPLEX_FUNCTION_JSON[:40],
COMPLEX_FUNCTION_JSON[40:-1],
COMPLEX_FUNCTION_JSON[-1],
],
id="gigachat3",
),
pytest.param(
[
CONTENT_TEXT[:3],
CONTENT_TEXT[3:5],
CONTENT_TEXT[5:],
TOOL_HEADER_GIGACHAT31,
COMPLEX_FUNCTION_JSON[:40],
COMPLEX_FUNCTION_JSON[40:-1],
COMPLEX_FUNCTION_JSON[-1],
],
id="gigachat31",
),
],
)
def test_streaming_tool_call_with_large_steps(
model_output_deltas: list[str],
gigachat_tokenizer: TokenizerLike,
):
"""
Test that the closing braces are streamed correctly.
"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
default_tokenizer gigachat_tokenizer
) )
model_output_deltas = [
"function call",
COMPLEX_FUNCTION_JSON[:40],
COMPLEX_FUNCTION_JSON[40:],
]
reconstructor = run_tool_extraction_streaming( reconstructor = run_tool_extraction_streaming(
tool_parser, tool_parser,
model_output_deltas, model_output_deltas,
......
...@@ -25,7 +25,12 @@ from vllm.tool_parsers.abstract_tool_parser import ToolParser ...@@ -25,7 +25,12 @@ from vllm.tool_parsers.abstract_tool_parser import ToolParser
logger = init_logger(__name__) logger = init_logger(__name__)
REGEX_FUNCTION_CALL = re.compile( REGEX_FUNCTION_CALL = re.compile(
r"function call(?:<\|role_sep\|>\n)?(\{.*)", r"(?:function call<\|role_sep\|>\n|<\|function_call\|>)(.*)",
re.DOTALL,
)
REGEX_CONTENT_PATTERN = re.compile(
r"^(.*?)(?:<\|message_sep\|>|<\|function_call\|>)",
re.DOTALL, re.DOTALL,
) )
...@@ -47,42 +52,58 @@ class GigaChat3ToolParser(ToolParser): ...@@ -47,42 +52,58 @@ class GigaChat3ToolParser(ToolParser):
self.tool_name_sent: bool = False self.tool_name_sent: bool = False
self.tool_id: str | None = None self.tool_id: str | None = None
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
self.content_buffer: str = "" self.end_content: bool = False
self.trigger_start = "function call{" self.streamed_args_for_tool: list[str] = []
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
return request
def extract_tool_calls( def extract_tool_calls(
self, self,
model_output: str, model_output: str,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> ExtractedToolCallInformation: ) -> ExtractedToolCallInformation:
match = REGEX_FUNCTION_CALL.search(model_output) function_call = None
if not match: content = None
return ExtractedToolCallInformation( if model_output.rstrip().endswith("</s>"):
tools_called=False, model_output = model_output[: model_output.rfind("</s>")]
tool_calls=[], m_func = REGEX_FUNCTION_CALL.search(model_output)
content=model_output, if m_func:
)
json_candidate = match.group(1).strip()
try: try:
data = json.loads(json_candidate) function_call = json.loads(m_func.group(1), strict=False)
if (
isinstance(function_call, dict)
and "name" in function_call
and "arguments" in function_call
):
if not isinstance(function_call["arguments"], dict):
function_call = None
else:
function_call = None
except json.JSONDecodeError: except json.JSONDecodeError:
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=False, tools_called=False,
tool_calls=[], tool_calls=[],
content=model_output, content=model_output,
) )
if not (isinstance(data, dict) and "name" in data and "arguments" in data): m_content = REGEX_CONTENT_PATTERN.search(model_output)
content = m_content.group(1) if m_content else model_output
if not function_call:
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=False, tools_called=False,
tool_calls=[], tool_calls=[],
content=model_output, content=content if content else None,
) )
name = data["name"] name = function_call["name"]
args = data["arguments"] args = function_call["arguments"]
if not isinstance(args, str): if not isinstance(args, str):
args = json.dumps(args, ensure_ascii=False) args = json.dumps(function_call["arguments"], ensure_ascii=False)
return ExtractedToolCallInformation(
tool_calls = [ tools_called=True,
tool_calls=[
ToolCall( ToolCall(
type="function", type="function",
function=FunctionCall( function=FunctionCall(
...@@ -90,14 +111,8 @@ class GigaChat3ToolParser(ToolParser): ...@@ -90,14 +111,8 @@ class GigaChat3ToolParser(ToolParser):
arguments=args, arguments=args,
), ),
) )
] ],
prefix = model_output[: match.start()] content=content if content else None,
content = prefix.rstrip() if prefix and prefix.strip() else None
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content,
) )
def extract_tool_calls_streaming( def extract_tool_calls_streaming(
...@@ -110,39 +125,37 @@ class GigaChat3ToolParser(ToolParser): ...@@ -110,39 +125,37 @@ class GigaChat3ToolParser(ToolParser):
delta_token_ids: Sequence[int], delta_token_ids: Sequence[int],
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> DeltaMessage | None: ) -> DeltaMessage | None:
content = None
func_name = None func_name = None
cur_args = None cur_args = None
m_func = REGEX_FUNCTION_CALL.search(current_text)
if not self.tool_started: if not self.tool_started:
match = REGEX_FUNCTION_CALL.search(current_text) m_content = REGEX_CONTENT_PATTERN.search(delta_text)
if match: if m_content:
self.tool_started = True content = m_content.group(1)
self.content_buffer = "" self.end_content = True
else: else:
self.content_buffer += delta_text if not self.end_content:
clean_buffer = self.content_buffer.lstrip() content = delta_text
is_prefix = self.trigger_start.startswith(clean_buffer) if m_func:
starts_with_trigger = clean_buffer.startswith(self.trigger_start) self.tool_started = True
if is_prefix or starts_with_trigger: if content:
return None return DeltaMessage(content=content)
else: if not m_func:
flush_text = self.content_buffer
self.content_buffer = ""
return DeltaMessage(content=flush_text)
match = REGEX_FUNCTION_CALL.search(current_text)
if not match:
return None return None
json_tail = match.group(1).strip() json_tail = m_func.group(1).strip()
name_match = NAME_REGEX.search(json_tail) name_match = NAME_REGEX.search(json_tail)
if name_match: if name_match:
func_name = name_match.group(1) func_name = name_match.group(1)
args_match = ARGS_REGEX.search(json_tail) args_match = ARGS_REGEX.search(json_tail)
if args_match: if args_match:
cur_args = args_match.group(1).strip() cur_args = args_match.group(1).strip()
if cur_args.endswith("</s>"):
cur_args = cur_args[: -len("</s>")]
if cur_args.endswith("}"): # last '}' end of json if cur_args.endswith("}"): # last '}' end of json
try: try:
candidate = cur_args[:-1].strip() candidate = cur_args[:-1].strip()
json.loads(candidate) json.loads(candidate, strict=False)
cur_args = candidate cur_args = candidate
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
...@@ -165,11 +178,10 @@ class GigaChat3ToolParser(ToolParser): ...@@ -165,11 +178,10 @@ class GigaChat3ToolParser(ToolParser):
).model_dump(exclude_none=True), ).model_dump(exclude_none=True),
) )
], ],
content=None,
) )
if cur_args is None: if cur_args is None:
return None return None
prev_args = self.prev_tool_call_arr[0].get("arguments", "") prev_args = self.prev_tool_call_arr[0].get("arguments_str", "")
if not prev_args: if not prev_args:
delta_args = cur_args delta_args = cur_args
elif cur_args.startswith(prev_args): elif cur_args.startswith(prev_args):
...@@ -178,7 +190,15 @@ class GigaChat3ToolParser(ToolParser): ...@@ -178,7 +190,15 @@ class GigaChat3ToolParser(ToolParser):
return None return None
if not delta_args: if not delta_args:
return None return None
self.prev_tool_call_arr[0]["arguments"] = cur_args self.prev_tool_call_arr[0]["arguments_str"] = cur_args
try:
args_dict = json.loads(cur_args, strict=False)
self.prev_tool_call_arr[0]["arguments"] = args_dict
except json.JSONDecodeError:
self.prev_tool_call_arr[0]["arguments"] = {}
if len(self.streamed_args_for_tool) <= 0:
self.streamed_args_for_tool.append("")
self.streamed_args_for_tool[0] = cur_args
return DeltaMessage( return DeltaMessage(
tool_calls=[ tool_calls=[
DeltaToolCall( DeltaToolCall(
...@@ -188,5 +208,4 @@ class GigaChat3ToolParser(ToolParser): ...@@ -188,5 +208,4 @@ class GigaChat3ToolParser(ToolParser):
).model_dump(exclude_none=True), ).model_dump(exclude_none=True),
) )
], ],
content=None,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment