# serving_tokenization.py (3.54 KB)
# Newer Older
1
2
3
4
from typing import List, Optional

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
5
6
7
from vllm.entrypoints.chat_utils import (ConversationMessage,
                                         load_chat_template,
                                         parse_chat_message_content)
8
9
10
11
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              DetokenizeResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
12
13
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
14
15
16
17
18
19
20
21


class OpenAIServingTokenization(OpenAIServing):
    """Handlers for the OpenAI-compatible tokenize / detokenize endpoints.

    Model checking, LoRA adapter selection and prompt tokenization are
    delegated to the shared ``OpenAIServing`` base class helpers.
    """

    def __init__(self,
                 engine: AsyncLLMEngine,
                 model_config: ModelConfig,
                 served_model_names: List[str],
                 lora_modules: Optional[List[LoRAModulePath]] = None,
                 chat_template: Optional[str] = None):
        """Create the tokenization handler.

        Args:
            engine: Engine forwarded to ``OpenAIServing``; its tokenizer is
                fetched per request (optionally for a LoRA adapter).
            model_config: Model configuration forwarded to ``OpenAIServing``
                and also used when parsing chat messages.
            served_model_names: Names under which the model is served,
                forwarded to ``OpenAIServing``.
            lora_modules: Optional LoRA adapters, forwarded to
                ``OpenAIServing``.
            chat_template: Optional chat template, resolved through
                ``load_chat_template``.
        """
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules)

        # If this is None we use the tokenizer's default chat template
        self.chat_template = load_chat_template(chat_template)

    async def create_tokenize(self,
                              request: TokenizeRequest) -> TokenizeResponse:
        """Tokenize either a raw prompt or a list of chat messages.

        Exactly one of ``request.prompt`` / ``request.messages`` must be
        set; empty values (``""``, ``[]``) count as "not provided". Chat
        messages are first rendered to text with the chat template and then
        tokenized the same way as a plain prompt.

        Returns:
            A ``TokenizeResponse`` with the token ids, their count and the
            served model's max length. On validation failure, the error
            response from ``_check_model`` / ``create_error_response`` is
            returned instead (not reflected in the return annotation).
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        if not (request.prompt or request.messages):
            return self.create_error_response(
                "Either `prompt` or `messages` should be provided.")

        if request.prompt and request.messages:
            return self.create_error_response(
                "Only one of `prompt` or `messages` should be provided.")

        _, lora_request = self._maybe_get_adapter(request)
        tokenizer = await self.engine.get_tokenizer(lora_request)

        if request.messages:
            conversation: List[ConversationMessage] = []
            for message in request.messages:
                result = parse_chat_message_content(message, self.model_config,
                                                    tokenizer)
                conversation.extend(result.messages)

            # Render the conversation to text only (tokenize=False) so it
            # goes through the same tokenization path as a plain prompt.
            # NOTE: this overwrites the request's ``prompt`` field in place.
            request.prompt = tokenizer.apply_chat_template(
                add_generation_prompt=request.add_generation_prompt,
                conversation=conversation,
                tokenize=False,
                chat_template=self.chat_template)

        # Only the token ids are needed here; the text half is discarded.
        input_ids, _ = await self._validate_prompt_and_tokenize(
            request,
            tokenizer,
            prompt=request.prompt,
            add_special_tokens=request.add_special_tokens)

        return TokenizeResponse(tokens=input_ids,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
            self, request: DetokenizeRequest) -> DetokenizeResponse:
        """Convert ``request.tokens`` back into text.

        Returns:
            A ``DetokenizeResponse`` whose ``prompt`` is the text that
            ``_validate_prompt_and_tokenize`` yields for the given token ids
            (presumably their decoding), or the error response from
            ``_check_model`` (not reflected in the return annotation).
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        _, lora_request = self._maybe_get_adapter(request)
        tokenizer = await self.engine.get_tokenizer(lora_request)

        # Only the text half is needed here; the id half is discarded.
        _, input_text = await self._validate_prompt_and_tokenize(
            request, tokenizer, prompt_ids=request.tokens)

        return DetokenizeResponse(prompt=input_text)