tokenizer_base.py 3.79 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# SPDX-License-Identifier: Apache-2.0

import importlib
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

if TYPE_CHECKING:
    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam


class TokenizerBase(ABC):
    """Abstract interface that concrete tokenizer classes must implement.

    Every property and method below except ``__len__`` is abstract, so a
    subclass can only be instantiated once it overrides all of them. The
    bodies defined here raise ``NotImplementedError`` when reached (for
    example through ``super()``).

    NOTE(review): the member names look like they follow the Hugging Face
    tokenizer API; their precise semantics are defined by the subclasses,
    not by this file -- confirm against the implementations.
    """

    # ------------------------------------------------------------------
    # Abstract read-only properties
    # ------------------------------------------------------------------

    @property
    @abstractmethod
    def all_special_tokens_extended(self) -> List[str]:
        """Abstract property; declared to yield a list of strings."""
        raise NotImplementedError

    @property
    @abstractmethod
    def all_special_tokens(self) -> List[str]:
        """Abstract property; declared to yield a list of strings."""
        raise NotImplementedError

    @property
    @abstractmethod
    def all_special_ids(self) -> List[int]:
        """Abstract property; declared to yield a list of ints."""
        raise NotImplementedError

    @property
    @abstractmethod
    def bos_token_id(self) -> int:
        """Abstract property; declared to yield an int."""
        raise NotImplementedError

    @property
    @abstractmethod
    def eos_token_id(self) -> int:
        """Abstract property; declared to yield an int."""
        raise NotImplementedError

    @property
    @abstractmethod
    def sep_token(self) -> str:
        """Abstract property; declared to yield a string."""
        raise NotImplementedError

    @property
    @abstractmethod
    def pad_token(self) -> str:
        """Abstract property; declared to yield a string."""
        raise NotImplementedError

    @property
    @abstractmethod
    def is_fast(self) -> bool:
        """Abstract property; declared to yield a bool."""
        raise NotImplementedError

    @property
    @abstractmethod
    def vocab_size(self) -> int:
        """Abstract property; declared to yield an int.

        This is the value that ``len(tokenizer)`` reports.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def max_token_id(self) -> int:
        """Abstract property; declared to yield an int."""
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Concrete behaviour
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        """Return ``self.vocab_size``."""
        size = self.vocab_size
        return size

    # ------------------------------------------------------------------
    # Abstract methods
    # ------------------------------------------------------------------

    @abstractmethod
    def __call__(
        self,
        text: Union[str, List[str], List[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        """Abstract call operator; subclasses must override.

        Args:
            text: A string, a list of strings, or a list of ints.
            text_pair: Optional second string; defaults to ``None``.
            add_special_tokens: Defaults to ``False``.
            truncation: Defaults to ``False``.
            max_length: Optional int; defaults to ``None``.

        The return type is not annotated here; it is left to subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def get_vocab(self) -> Dict[str, int]:
        """Abstract; declared to return a ``str -> int`` mapping."""
        raise NotImplementedError

    @abstractmethod
    def get_added_vocab(self) -> Dict[str, int]:
        """Abstract; declared to return a ``str -> int`` mapping."""
        raise NotImplementedError

    @abstractmethod
    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        """Abstract; declared to turn one string into a list of ints.

        Args:
            text: The input string.
            truncation: Defaults to ``False``.
            max_length: Optional int; defaults to ``None``.
        """
        raise NotImplementedError

    @abstractmethod
    def encode(
        self,
        text: str,
        add_special_tokens: Optional[bool] = None,
    ) -> List[int]:
        """Abstract; declared to turn a string into a list of ints.

        Args:
            text: The input string.
            add_special_tokens: Optional bool; defaults to ``None``.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_chat_template(
        self,
        messages: List["ChatCompletionMessageParam"],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> List[int]:
        """Abstract; declared to map chat messages to a list of ints.

        Args:
            messages: List of ``ChatCompletionMessageParam`` items.
            tools: Optional list of dicts; defaults to ``None``.
            **kwargs: Extra keyword arguments, left to the subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Abstract; declared to map a list of strings to one string."""
        raise NotImplementedError

    @abstractmethod
    def decode(
        self,
        ids: Union[List[int], int],
        skip_special_tokens: bool = True,
    ) -> str:
        """Abstract; declared to map one int or a list of ints to a string.

        Args:
            ids: A single int or a list of ints.
            skip_special_tokens: Defaults to ``True``.
        """
        raise NotImplementedError

    @abstractmethod
    def convert_ids_to_tokens(
        self,
        ids: List[int],
        skip_special_tokens: bool = True,
    ) -> List[str]:
        """Abstract; declared to map a list of ints to a list of strings.

        Args:
            ids: List of ints.
            skip_special_tokens: Defaults to ``True``.
        """
        raise NotImplementedError


class TokenizerRegistry:
    """Name-keyed table of tokenizer classes that are imported on demand.

    Registration stores only a module path and a class name; the module is
    imported when ``get_tokenizer`` is called for that name.
    """

    # Maps a tokenizer name to its (module path, class name) pair.
    REGISTRY: Dict[str, Tuple[str, str]] = {}

    @staticmethod
    def register(name: str, module: str, class_name: str) -> None:
        """Record where the tokenizer called ``name`` can be found.

        Args:
            name: Key under which the tokenizer is looked up later.
            module: Importable module path that defines the class.
            class_name: Attribute name of the class inside ``module``.

        An existing entry with the same ``name`` is overwritten.
        """
        entry = (module, class_name)
        TokenizerRegistry.REGISTRY[name] = entry

    @staticmethod
    def get_tokenizer(tokenizer_name: str, *args, **kwargs) -> TokenizerBase:
        """Import the registered class and build it via ``from_pretrained``.

        Args:
            tokenizer_name: Name previously passed to ``register``.
            *args: Positional arguments forwarded to ``from_pretrained``.
            **kwargs: Keyword arguments forwarded to ``from_pretrained``.

        Returns:
            Whatever the registered class's ``from_pretrained`` returns.

        Raises:
            ValueError: If ``tokenizer_name`` has no registry entry.
        """
        registry = TokenizerRegistry.REGISTRY
        entry = registry.get(tokenizer_name)
        if entry is None:
            raise ValueError(f"Tokenizer {tokenizer_name} not found.")

        # entry[0] is the module path, entry[1] the class name.
        module = importlib.import_module(entry[0])
        loader = getattr(module, entry[1]).from_pretrained
        return loader(*args, **kwargs)