tokenizer_group.py 5.51 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional
5

6
7
from typing_extensions import assert_never

8
9
from vllm.config import ModelConfig, SchedulerConfig
from vllm.config.lora import LoRAConfig
10
from vllm.lora.request import LoRARequest
11
from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
12
                                               get_lora_tokenizer,
13
14
                                               get_lora_tokenizer_async,
                                               get_tokenizer)
15
16
from vllm.utils import LRUCache

17

18
class TokenizerGroup:
    """A group of tokenizers that can be used for LoRA adapters.

    The base tokenizer is loaded eagerly in ``__init__``. Adapter-specific
    tokenizers are loaded on first use and kept in an LRU cache keyed by
    ``LoRARequest.lora_int_id``.
    """

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int], **tokenizer_config):
        """Load the base tokenizer and prepare the per-LoRA tokenizer cache.

        Args:
            tokenizer_id: Identifier passed to ``get_tokenizer``.
            enable_lora: Whether LoRA-specific tokenizers may be looked up.
            max_num_seqs: Combined with ``max_loras`` to size the LoRA
                tokenizer cache.
            max_input_length: Token limit enforced after encoding; None
                disables the check.
            **tokenizer_config: Keyword arguments forwarded verbatim to the
                tokenizer loaders. ``truncation_side`` (default ``"left"``)
                and ``max_loras`` (default 0) are also read here.
        """
        self.tokenizer_id = tokenizer_id
        self.tokenizer_config = tokenizer_config
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.truncation_side = tokenizer_config.get("truncation_side", "left")
        self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)

        requested_loras = tokenizer_config.get("max_loras", 0)
        # When LoRA is disabled the cache is never consulted (the lookup
        # methods return the base tokenizer first), so it gets no capacity.
        if enable_lora:
            cache_capacity = max(requested_loras, max_num_seqs)
        else:
            cache_capacity = 0
        self.lora_tokenizers = LRUCache[int, AnyTokenizer](
            capacity=cache_capacity)

    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Get the maximum input length for the LoRA request.

        NOTE: ``lora_request`` is currently ignored. This always reports the
        group-wide ``max_input_length`` (None means no limit), whereas
        ``_raise_if_input_too_long`` may prefer the request's
        ``long_lora_max_len``.
        """
        return self.max_input_length

    def _raise_if_input_too_long(self,
                                 encoded_tokens: list[int],
                                 lora_request: Optional[LoRARequest] = None):
        """Raise if ``encoded_tokens`` exceeds the applicable length limit.

        The limit is the request's ``long_lora_max_len`` when a request is
        given and that value is truthy; otherwise it is
        ``self.max_input_length``. A limit of None disables the check.

        Raises:
            ValueError: With args ``("Input too long.", length, limit)``.
        """
        token_count = len(encoded_tokens)
        limit = self.max_input_length
        if lora_request:
            limit = lora_request.long_lora_max_len or limit
        if limit is None or token_count <= limit:
            return
        raise ValueError("Input too long.", token_count, limit)

    def _encode_checked(self, tokenizer: AnyTokenizer, prompt: str,
                        max_length: Optional[int],
                        truncation: Optional[bool],
                        add_special_tokens: Optional[bool],
                        lora_request: Optional[LoRARequest]) -> list[int]:
        """Encode ``prompt`` with ``tokenizer``, then enforce the length limit."""
        token_ids = encode_tokens(tokenizer,
                                  prompt,
                                  max_length=max_length,
                                  truncation=truncation,
                                  add_special_tokens=add_special_tokens)
        self._raise_if_input_too_long(token_ids, lora_request)
        return token_ids

    def encode(self,
               prompt: str,
               max_length: Optional[int] = None,
               truncation: Optional[bool] = None,
               lora_request: Optional[LoRARequest] = None,
               add_special_tokens: Optional[bool] = None) -> list[int]:
        """Tokenize ``prompt`` with the tokenizer selected for the request.

        Raises:
            ValueError: If the token count exceeds the applicable limit.
        """
        chosen = self.get_lora_tokenizer(lora_request)
        return self._encode_checked(chosen, prompt, max_length, truncation,
                                    add_special_tokens, lora_request)

    async def encode_async(
            self,
            prompt: str,
            max_length: Optional[int] = None,
            truncation: Optional[bool] = None,
            lora_request: Optional[LoRARequest] = None,
            add_special_tokens: Optional[bool] = None) -> list[int]:
        """Async variant of :meth:`encode`.

        Only the tokenizer lookup is awaited; encoding itself runs inline.

        Raises:
            ValueError: If the token count exceeds the applicable limit.
        """
        chosen = await self.get_lora_tokenizer_async(lora_request)
        return self._encode_checked(chosen, prompt, max_length, truncation,
                                    add_special_tokens, lora_request)

    def _cache_lora_tokenizer(self, lora_id: int,
                              loaded: Optional[AnyTokenizer]) -> AnyTokenizer:
        """Store a freshly loaded LoRA tokenizer (falsy -> base) and return it."""
        tokenizer = loaded or self.tokenizer
        self.lora_tokenizers.put(lora_id, tokenizer)
        return tokenizer

    def get_lora_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request``.

        The base tokenizer is returned when no request is given or LoRA is
        disabled. Otherwise the cached adapter tokenizer is returned, and it
        is loaded and cached on a miss.
        """
        if not (lora_request and self.enable_lora):
            return self.tokenizer
        lora_id = lora_request.lora_int_id
        if lora_id in self.lora_tokenizers:
            return self.lora_tokenizers[lora_id]
        loaded = get_lora_tokenizer(lora_request, **self.tokenizer_config)
        return self._cache_lora_tokenizer(lora_id, loaded)

    async def get_lora_tokenizer_async(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Async variant of :meth:`get_lora_tokenizer`."""
        if not (lora_request and self.enable_lora):
            return self.tokenizer
        lora_id = lora_request.lora_int_id
        if lora_id in self.lora_tokenizers:
            return self.lora_tokenizers[lora_id]
        loaded = await get_lora_tokenizer_async(lora_request,
                                                **self.tokenizer_config)
        return self._cache_lora_tokenizer(lora_id, loaded)


def init_tokenizer_from_configs(model_config: ModelConfig,
                                scheduler_config: SchedulerConfig,
                                lora_config: Optional[LoRAConfig]):
    """Build a ``TokenizerGroup`` from the engine configuration objects.

    Pooling runners truncate from the right, while generate and draft runners
    truncate from the left. Any other runner type is rejected through
    ``assert_never``. LoRA support, and the ``max_loras`` cache hint, are
    enabled only when a LoRA config is supplied.
    """
    runner_type = model_config.runner_type
    if runner_type == "pooling":
        truncation_side = "right"
    elif runner_type == "generate" or runner_type == "draft":
        truncation_side = "left"
    else:
        assert_never(runner_type)

    lora_enabled = bool(lora_config)
    max_loras = lora_config.max_loras if lora_enabled else 0

    # Keyword order is preserved on purpose: it determines the order of
    # the ``**tokenizer_config`` mapping forwarded to the tokenizer loaders.
    return TokenizerGroup(
        tokenizer_id=model_config.tokenizer,
        enable_lora=lora_enabled,
        max_num_seqs=scheduler_config.max_num_seqs,
        max_loras=max_loras,
        max_input_length=None,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.tokenizer_revision,
        truncation_side=truncation_side)