# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import regex as re

from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser


class LongcatFlashToolParser(Hermes2ProToolParser):
    """Tool-call parser for LongCat-Flash model output.

    Reuses the parsing logic of :class:`Hermes2ProToolParser` and only swaps
    in LongCat's tool-call delimiters, ``<longcat_tool_call>`` and
    ``</longcat_tool_call>``, together with the regex and token-level data
    derived from them.
    """

    def __init__(self, tokenizer: TokenizerLike):
        """Initialize the parser and override the delimiter configuration.

        Args:
            tokenizer: Tokenizer forwarded to the base parser. The delimiters
                below are tokenized through ``self.model_tokenizer``, which is
                presumably set from ``tokenizer`` by the base class — confirm
                against ``Hermes2ProToolParser.__init__``.
        """
        super().__init__(tokenizer)

        # Replace whatever delimiters the base class configured with LongCat's.
        self.tool_call_start_token: str = "<longcat_tool_call>"
        self.tool_call_end_token: str = "</longcat_tool_call>"

        # Alternative 1 (group 1) captures the body of a complete, closed tool
        # call. Alternative 2 (group 2) captures everything after an opening
        # tag that is never closed, e.g. output cut off mid-call.
        # DOTALL lets a tool-call body span multiple lines.
        self.tool_call_regex = re.compile(
            r"<longcat_tool_call>(.*?)</longcat_tool_call>|<longcat_tool_call>(.*)",
            re.DOTALL,
        )

        # Token-id sequences for each delimiter. A delimiter may encode to
        # more than one id, so the full list is kept; add_special_tokens=False
        # keeps the tokenizer from adding special tokens around the delimiter.
        self.tool_call_start_token_ids = self.model_tokenizer.encode(
            self.tool_call_start_token, add_special_tokens=False
        )
        self.tool_call_end_token_ids = self.model_tokenizer.encode(
            self.tool_call_end_token, add_special_tokens=False
        )

        # Text of each delimiter token decoded one id at a time. Presumably
        # used by the base class's streaming logic to recognize a delimiter
        # that arrives split across several tokens — TODO confirm.
        self.tool_call_start_token_array = [
            self.model_tokenizer.decode([token_id])
            for token_id in self.tool_call_start_token_ids
        ]

        self.tool_call_end_token_array = [
            self.model_tokenizer.decode([token_id])
            for token_id in self.tool_call_end_token_ids
        ]