inference_interface.py 5.6 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import asyncio
from abc import abstractmethod
from itertools import zip_longest
from typing import Annotated, Any, ClassVar

from pydantic import BaseModel, BeforeValidator, ValidationError

from ..__init__ import GenericGenerationArgs
from ..inference.api import (
    ChatInferenceRequest,
    ChatInferenceResponse,
    GroupedChatInferenceResponse,
    GroupedInferenceResponse,
    InferenceRequest,
    InferenceResponse,
    LLMChatMessage,
)
from ..inference.chat_templates import ConversationTemplate


# Used when generating n responses for a single prompt.
def grouper(iterable, n, fillvalue=None):
    """Collect items from *iterable* into fixed-size tuples of length *n*.

    The final tuple is padded with *fillvalue* when the iterable's length
    is not an exact multiple of *n*.
    """
    # zip_longest over n references to ONE iterator yields consecutive
    # n-sized chunks (the classic "grouper" recipe from itertools docs).
    chunks = [iter(iterable)] * n
    return zip_longest(*chunks, fillvalue=fillvalue)


class InferenceInterface(BaseModel):
    """Inference interface for base (non-chat) language models.

    Concrete subclasses implement ``base_generate``; this class layers
    request preparation and optional n-samples-per-prompt grouping on top.
    """

    class Config:
        arbitrary_types_allowed = True

    # Whether the backend natively returns n samples per prompt. When False,
    # agenerate() emulates n by duplicating each prompt n times.
    supports_n: ClassVar[bool] = False

    def prepare_request(
        self, prompts: list[str], generation_args: GenericGenerationArgs
    ) -> InferenceRequest:
        """Wrap a list of raw prompt strings into an InferenceRequest."""
        assert all(isinstance(p, str) for p in prompts), "Prompt must be a list of strings"
        return InferenceRequest(prompt=prompts, generation_args=generation_args)

    @abstractmethod
    async def base_generate(self, request: InferenceRequest) -> list[InferenceResponse]:
        """Produce one response per prompt in *request* (subclass hook)."""
        raise NotImplementedError(
            "Direct Inference Classes must implement the base_generate method."
        )

    def duplicate_requests(self, request: InferenceRequest, n: int) -> InferenceRequest:
        """Return a copy of *request* with each prompt repeated *n* times.

        Each prompt's copies are kept adjacent ([p1, p1, ..., p2, p2, ...])
        so that fold_responses() — which groups n *consecutive* responses —
        reassembles all n samples of the same prompt into one group.

        Note: the previous ``request.prompt * n`` produced round-robin order
        ([p1, p2, ..., p1, p2, ...]), which made each folded group contain
        responses belonging to different prompts. The return annotation is
        also corrected: a single request is returned, not a list.
        """
        duplicated = [p for p in request.prompt for _ in range(n)]
        return request.model_copy(update={'prompt': duplicated})

    def fold_responses(
        self, responses: list[InferenceResponse], n: int
    ) -> list[GroupedInferenceResponse]:
        """Fold a flat response list into groups of n consecutive responses."""
        return [GroupedInferenceResponse(responses=x) for x in list(grouper(responses, n))]

    async def agenerate(
        self, request: InferenceRequest
    ) -> list[InferenceResponse] | list[GroupedInferenceResponse]:
        """Generate responses; returns grouped responses when ``request.n`` is set."""
        if not self.supports_n and request.n is not None:
            # Emulate n samples per prompt by duplicating prompts.
            request = self.duplicate_requests(request, request.n)

        generations = await self.base_generate(request)

        if request.n is not None:
            if self.supports_n:
                assert (
                    len(generations) == len(request.prompt) * request.n
                ), f"Number of generations ({len(generations)}) does not match number of prompts ({len(request.prompt)} * {request.n})."
            else:
                # request.prompt was already duplicated n times above, so a
                # one-to-one count check is the correct invariant here.
                assert len(generations) == len(
                    request.prompt
                ), f"Number of generations ({len(generations)}) does not match number of prompts ({len(request.prompt)})."
            generations = self.fold_responses(generations, request.n)

        return generations

    def generate(
        self, request: InferenceRequest
    ) -> list[InferenceResponse] | list[GroupedInferenceResponse]:
        """Synchronous wrapper around agenerate()."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop: safe to create one with asyncio.run().
            return asyncio.run(self.agenerate(request))
        else:
            # NOTE(review): run_until_complete() raises RuntimeError on a loop
            # that is already running; callers executing inside an event loop
            # should await agenerate() directly — confirm intended usage.
            return loop.run_until_complete(self.agenerate(request))


def ensure_template(value: Any) -> ConversationTemplate:
    """Coerce *value* into a ConversationTemplate (pydantic BeforeValidator).

    Accepts an existing ConversationTemplate instance (returned unchanged) or
    a string, which is resolved via ``ConversationTemplate.from_string``.

    Raises:
        ValueError: if *value* is neither a template nor a string. Pydantic
            wraps this into a ValidationError during model validation.
    """
    if isinstance(value, ConversationTemplate):
        return value
    if isinstance(value, str):
        return ConversationTemplate.from_string(value)
    # Bug fix: pydantic v2's ValidationError has no public message-only
    # constructor — constructing it directly raises a TypeError instead of
    # the intended error. Validators must signal failure with ValueError
    # (or AssertionError), which pydantic converts to a ValidationError.
    raise ValueError(f"Invalid conversation template: {value}")


class ChatInferenceInterface(InferenceInterface):
    """Inference interface for chat models.

    Renders chat message lists through a ConversationTemplate, delegates raw
    text generation to the underlying interface, and parses the replies back
    into chat messages.
    """

    # Accepts either a ConversationTemplate or a string (coerced by
    # ensure_template before validation).
    conversation_template: Annotated[ConversationTemplate, BeforeValidator(ensure_template)]

    def prepare_request(
        self, prompts: list[str | list[LLMChatMessage]], generation_args: GenericGenerationArgs
    ) -> ChatInferenceRequest:
        """Normalize each prompt into a message list and build a chat request."""
        normalized = []
        for item in prompts:
            if isinstance(item, str):
                # Bare strings become single-turn user messages.
                normalized.append([LLMChatMessage(role='user', content=item)])
            else:
                normalized.append(item)
        return ChatInferenceRequest(prompt=normalized, generation_args=generation_args)

    async def base_generate(self, request: ChatInferenceRequest) -> list[ChatInferenceResponse]:
        """Render messages to text, delegate generation, and re-parse replies.

        Delegation goes through super() so a concrete text backend mixed in
        via the MRO supplies the actual generation.
        """
        rendered = [self.conversation_template.format(messages) for messages in request.prompt]
        text_request = InferenceRequest(
            prompt=rendered, generation_args=request.generation_args
        )
        raw_responses = await super().base_generate(text_request)
        parsed_messages = self.conversation_template.parse_response(raw_responses)
        results = []
        for message, raw in zip(parsed_messages, raw_responses):
            payload = raw.model_dump(exclude={'response'})
            results.append(ChatInferenceResponse(response=message, **payload))
        return results

    def generate(
        self, request: ChatInferenceRequest
    ) -> list[ChatInferenceResponse] | list[GroupedChatInferenceResponse]:
        """Synchronous chat generation; narrows the base method's return type."""
        return super().generate(request)

    async def agenerate(
        self, request: ChatInferenceRequest
    ) -> list[ChatInferenceResponse] | list[GroupedChatInferenceResponse]:
        """Async chat generation; narrows the base method's return type."""
        return await super().agenerate(request)


class ReturnsRaw(InferenceInterface):
    """Marker mix-in: the interface can return the complete string fed to the LLM.

    Carries no behavior of its own; presumably checked via isinstance /
    issubclass by callers elsewhere in the project — confirm.
    """

    # TODO: Should this be a mix-in or a class variable?


class ReturnsTokens(InferenceInterface):
    """Marker mix-in: the interface can return the complete list of tokens fed to the LLM.

    Carries no behavior of its own; presumably checked via isinstance /
    issubclass by callers elsewhere in the project — confirm.
    """

    # TODO: Should this be a mix-in or a class variable?


class ReturnsLogProbs(ReturnsTokens):
    """Marker mix-in: the interface can return logprobs for a set of tokens.

    Extends ReturnsTokens, so any logprob-capable interface also advertises
    token access.
    """

    # TODO: Should this be a mix-in or a class variable?