tokenizer_group.py 5.14 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional
5

6
from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
7
from vllm.lora.request import LoRARequest
8
from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
9
                                               get_lora_tokenizer,
10
11
                                               get_lora_tokenizer_async,
                                               get_tokenizer)
12
13
from vllm.utils import LRUCache

14

15
class TokenizerGroup:
    """A group of tokenizers that can be used for LoRA adapters.

    Holds one base tokenizer plus an LRU cache of per-LoRA tokenizers keyed
    by ``LoRARequest.lora_int_id``. When LoRA is disabled, or no LoRA request
    is supplied, every lookup resolves to the base tokenizer.
    """

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int], **tokenizer_config):
        """Create the base tokenizer and the per-LoRA tokenizer cache.

        Args:
            tokenizer_id: Identifier passed to ``get_tokenizer`` for the base
                tokenizer.
            enable_lora: Whether LoRA-specific tokenizers may be loaded.
            max_num_seqs: Lower bound on the LoRA tokenizer cache capacity
                (when LoRA is enabled).
            max_input_length: Default cap on the encoded prompt length, or
                ``None`` for no cap.
            **tokenizer_config: Keyword arguments forwarded to
                ``get_tokenizer`` and ``get_lora_tokenizer(_async)``. An
                optional ``max_loras`` entry also sizes the cache.
        """
        self.tokenizer_id = tokenizer_id
        self.tokenizer_config = tokenizer_config
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
        # ``max_loras`` is read with ``.get`` (not popped), so it is still
        # forwarded to the tokenizer loaders together with the rest of the
        # config.
        max_loras = tokenizer_config.get("max_loras", 0)
        # The cache is only consulted when LoRA is enabled (the lookups below
        # return the base tokenizer otherwise), so it gets zero capacity then.
        self.lora_tokenizers = LRUCache[int, AnyTokenizer](
            capacity=max(max_loras, max_num_seqs) if enable_lora else 0)

    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Get the maximum input length for the LoRA request.

        Returns the group-wide ``max_input_length`` (``None`` means no cap).

        NOTE(review): ``lora_request`` is currently ignored, whereas
        ``_raise_if_input_too_long`` prefers ``lora_request.long_lora_max_len``
        when it is set. Confirm whether callers expect the per-LoRA limit here.
        """
        return self.max_input_length

    def _raise_if_input_too_long(self,
                                 encoded_tokens: list[int],
                                 lora_request: Optional[LoRARequest] = None):
        """Raise ``ValueError`` if ``encoded_tokens`` exceeds the length limit.

        The limit is ``lora_request.long_lora_max_len`` when a request is
        given and that value is truthy, otherwise ``self.max_input_length``.
        A limit of ``None`` disables the check.

        Raises:
            ValueError: with args ``("Input too long.", length, limit)``.
        """
        input_length = len(encoded_tokens)
        if lora_request:
            max_input_length = (lora_request.long_lora_max_len
                                or self.max_input_length)
        else:
            max_input_length = self.max_input_length
        if max_input_length is not None and input_length > max_input_length:
            raise ValueError("Input too long.", input_length, max_input_length)

    def _encode_and_check(self, tokenizer: AnyTokenizer, prompt: str,
                          max_length: Optional[int],
                          truncation: Optional[bool],
                          lora_request: Optional[LoRARequest],
                          add_special_tokens: Optional[bool]) -> list[int]:
        """Encode ``prompt`` with ``tokenizer`` and enforce the length limit.

        Shared body of ``encode`` and ``encode_async``; only the way the
        tokenizer is obtained differs between them.
        """
        token_ids = encode_tokens(tokenizer,
                                  prompt,
                                  max_length=max_length,
                                  truncation=truncation,
                                  add_special_tokens=add_special_tokens)
        self._raise_if_input_too_long(token_ids, lora_request)
        return token_ids

    def encode(self,
               prompt: str,
               max_length: Optional[int] = None,
               truncation: Optional[bool] = None,
               lora_request: Optional[LoRARequest] = None,
               add_special_tokens: Optional[bool] = None) -> list[int]:
        """Encode ``prompt`` into token ids using the tokenizer for the request.

        Raises:
            ValueError: if the encoded prompt exceeds the applicable limit
                (see ``_raise_if_input_too_long``).
        """
        tokenizer = self.get_lora_tokenizer(lora_request)
        return self._encode_and_check(tokenizer, prompt, max_length,
                                      truncation, lora_request,
                                      add_special_tokens)

    async def encode_async(
            self,
            prompt: str,
            max_length: Optional[int] = None,
            truncation: Optional[bool] = None,
            lora_request: Optional[LoRARequest] = None,
            add_special_tokens: Optional[bool] = None) -> list[int]:
        """Async counterpart of ``encode``.

        The tokenizer is resolved via ``get_lora_tokenizer_async``; the
        encoding itself runs synchronously.
        """
        tokenizer = await self.get_lora_tokenizer_async(lora_request)
        return self._encode_and_check(tokenizer, prompt, max_length,
                                      truncation, lora_request,
                                      add_special_tokens)

    def _store_lora_tokenizer(
            self, lora_id: int,
            loaded: Optional[AnyTokenizer]) -> AnyTokenizer:
        """Cache and return ``loaded``, falling back to the base tokenizer.

        ``loaded`` is whatever ``get_lora_tokenizer(_async)`` returned; a falsy
        result is replaced by ``self.tokenizer`` before caching.
        """
        tokenizer = loaded or self.tokenizer
        self.lora_tokenizers.put(lora_id, tokenizer)
        return tokenizer

    def get_lora_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer to use for ``lora_request``.

        Returns the base tokenizer when no request is given or LoRA is
        disabled. Otherwise the tokenizer is served from the LRU cache,
        loading it via ``get_lora_tokenizer`` on a cache miss.
        """
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        lora_id = lora_request.lora_int_id
        if lora_id in self.lora_tokenizers:
            return self.lora_tokenizers[lora_id]
        return self._store_lora_tokenizer(
            lora_id, get_lora_tokenizer(lora_request, **self.tokenizer_config))

    async def get_lora_tokenizer_async(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Async counterpart of ``get_lora_tokenizer``.

        On a cache miss the tokenizer is loaded with
        ``get_lora_tokenizer_async``.
        """
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        lora_id = lora_request.lora_int_id
        if lora_id in self.lora_tokenizers:
            return self.lora_tokenizers[lora_id]
        # NOTE(review): the await below yields, so concurrent misses for the
        # same lora_id may each load a tokenizer; the last ``put`` wins.
        loaded = await get_lora_tokenizer_async(lora_request,
                                                **self.tokenizer_config)
        return self._store_lora_tokenizer(lora_id, loaded)


def init_tokenizer_from_configs(model_config: ModelConfig,
                                scheduler_config: SchedulerConfig,
                                lora_config: Optional[LoRAConfig]):
    """Build a ``TokenizerGroup`` from model, scheduler and LoRA configs.

    LoRA tokenizers are enabled iff ``lora_config`` is truthy, in which case
    its ``max_loras`` is forwarded as ``max_loras`` (``0`` otherwise). No
    input-length cap is configured (``max_input_length=None``).
    """
    lora_enabled = bool(lora_config)
    lora_slots = lora_config.max_loras if lora_config else 0
    return TokenizerGroup(
        tokenizer_id=model_config.tokenizer,
        enable_lora=lora_enabled,
        max_num_seqs=scheduler_config.max_num_seqs,
        max_loras=lora_slots,
        max_input_length=None,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.tokenizer_revision,
        truncation_side=model_config.truncation_side,
    )