serving_tokenization.py 2.91 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from typing import List, Optional, Union

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.chat_utils import (ConversationMessage,
                                                load_chat_template,
                                                parse_chat_message_content)
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              DetokenizeResponse,
                                              ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
from vllm.entrypoints.openai.serving_engine import OpenAIServing


class OpenAIServingTokenization(OpenAIServing):
    """Handle tokenize / detokenize requests for the OpenAI-compatible server.

    Both handlers first validate the requested model with ``_check_model``.
    Invalid input is reported by returning an ``ErrorResponse`` rather than
    by raising.
    """

    def __init__(self,
                 engine: AsyncLLMEngine,
                 model_config: ModelConfig,
                 served_model_names: List[str],
                 chat_template: Optional[str] = None):
        """Create the tokenization handler.

        Args:
            engine: Engine forwarded to ``OpenAIServing``.
            model_config: Model configuration forwarded to ``OpenAIServing``.
            served_model_names: Model names forwarded to ``OpenAIServing``.
            chat_template: Optional chat template passed to
                ``load_chat_template``.
        """
        # No LoRA modules are registered for this handler.
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=None)

        load_chat_template(self, chat_template)

    async def create_tokenize(
        self,
        request: TokenizeRequest,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        """Tokenize either a raw ``prompt`` or a list of chat ``messages``.

        Exactly one of ``request.prompt`` / ``request.messages`` must be set.
        Messages are rendered to text with ``tokenizer.apply_chat_template``
        (``tokenize=False``) and the resulting text is then tokenized like a
        plain prompt.

        Returns:
            A ``TokenizeResponse`` with the token ids, their count and
            ``max_model_len``; or an ``ErrorResponse`` if the model check
            fails, the prompt/messages combination is invalid, or rendering
            or tokenizing the input fails.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        if not (request.prompt or request.messages):
            return self.create_error_response(
                "Either `prompt` or `messages` should be provided.")

        if request.prompt and request.messages:
            return self.create_error_response(
                "Only one of `prompt` or `messages` should be provided.")

        # Keep the text to tokenize in a local instead of overwriting
        # ``request.prompt``, so the caller's request object is not mutated.
        prompt = request.prompt
        if request.messages:
            try:
                conversation: List[ConversationMessage] = []
                for message in request.messages:
                    conversation.extend(
                        parse_chat_message_content(self, message).messages)

                prompt = self.tokenizer.apply_chat_template(
                    add_generation_prompt=request.add_generation_prompt,
                    conversation=conversation,
                    tokenize=False)
            except Exception as e:
                # Request boundary: a malformed message or a chat-template
                # rendering failure is reported back to the client as an
                # error response instead of escaping as an unhandled error.
                return self.create_error_response(str(e))

        try:
            input_ids, _ = self._validate_prompt_and_tokenize(
                request,
                prompt=prompt,
                add_special_tokens=request.add_special_tokens)
        except ValueError as e:
            # NOTE(review): the helper is presumably what signals invalid
            # prompt input via ValueError; surface it like other bad input.
            return self.create_error_response(str(e))

        return TokenizeResponse(tokens=input_ids,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """Convert ``request.tokens`` back into text.

        Returns:
            A ``DetokenizeResponse`` whose ``prompt`` is the text produced by
            ``_validate_prompt_and_tokenize`` for the given token ids; or an
            ``ErrorResponse`` if the model check fails or that helper raises
            ``ValueError``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        try:
            _, input_text = self._validate_prompt_and_tokenize(
                request, prompt_ids=request.tokens)
        except ValueError as e:
            # NOTE(review): presumably raised for invalid token input such as
            # an empty ``tokens`` list — confirm against the helper.
            return self.create_error_response(str(e))

        return DetokenizeResponse(prompt=input_text)