openai_compatible.py 7.55 KB
Newer Older
sunzhq2's avatar
sunzhq2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import os
from time import perf_counter
from openai import APIStatusError, BadRequestError, OpenAI, PermissionDeniedError, UnprocessableEntityError
from openai._types import NOT_GIVEN
from openai.types.chat import ChatCompletion
from typing import Any, Dict, List, Optional, Tuple, Union

from evalscope.api.messages import ChatMessage
from evalscope.api.model import ChatCompletionChoice, GenerateConfig, ModelAPI, ModelOutput
from evalscope.api.tool import ToolChoice, ToolInfo
from evalscope.utils import get_logger
from evalscope.utils.argument_utils import get_supported_params
from evalscope.utils.function_utils import retry_call
from .utils.openai import (
    chat_choices_from_openai,
    collect_stream_response,
    model_output_from_openai,
    openai_chat_messages,
    openai_chat_tool_choice,
    openai_chat_tools,
    openai_completion_params,
    openai_handle_bad_request,
)

logger = get_logger()


class OpenAICompatibleAPI(ModelAPI):
    """ModelAPI implementation for servers speaking the OpenAI chat-completions protocol.

    Credentials and endpoint come from the constructor arguments or, as a
    fallback, from the ``EVALSCOPE_API_KEY`` / ``EVALSCOPE_BASE_URL``
    environment variables. Subclasses may customize behavior through the
    ``resolve_tools`` / ``completion_params`` / ``validate_request_params`` /
    ``on_response`` / ``chat_choices_from_completion`` / ``handle_bad_request``
    hooks.
    """

    def __init__(
        self,
        model_name: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        """Create the underlying OpenAI HTTP client.

        Args:
            model_name: Name of the model to query.
            base_url: Endpoint root; falls back to ``EVALSCOPE_BASE_URL``.
            api_key: API key; falls back to ``EVALSCOPE_API_KEY``.
            config: Default generation configuration.
            **model_args: Extra keyword arguments forwarded to ``OpenAI(...)``.

        Raises:
            ValueError: If no API key or base URL can be resolved.
        """
        super().__init__(
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            config=config,
        )

        # use service prefix to lookup api_key
        self.api_key = api_key or os.environ.get('EVALSCOPE_API_KEY', None)
        # Raise (not assert) so validation survives `python -O`, which strips asserts.
        if not self.api_key:
            raise ValueError(f'API key for {model_name} not found')

        # use service prefix to lookup base_url
        self.base_url = base_url or os.environ.get('EVALSCOPE_BASE_URL', None)
        if not self.base_url:
            raise ValueError(f'Base URL for {model_name} not found')

        # Normalize: drop trailing slash and a pasted-in '/chat/completions' suffix,
        # since the OpenAI client appends the route itself.
        self.base_url = self.base_url.rstrip('/').removesuffix('/chat/completions')

        # create http client
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            **model_args,
        )

    def generate(
        self,
        input: List[ChatMessage],
        tools: List[ToolInfo],
        tool_choice: ToolChoice,
        config: GenerateConfig,
    ) -> ModelOutput:
        """Run one chat completion (with retries) and convert it to a ModelOutput.

        Handles both streaming and non-streaming responses; for streams the
        time-to-first-token (TTFT) is measured and recorded in the output
        metadata.

        Args:
            input: Conversation history to send.
            tools: Tool definitions available to the model.
            tool_choice: Tool-selection directive.
            config: Generation parameters for this call.

        Returns:
            ModelOutput built from the (possibly collected) completion.

        Raises:
            ValueError: Re-raised when the model returns an invalid response.
        """
        # setup request and response for ModelCall
        request: Dict[str, Any] = {}
        response: Dict[str, Any] = {}

        # give subclasses a chance to rewrite tools/tool_choice/config
        tools, tool_choice, config = self.resolve_tools(tools, tool_choice, config)

        # get completion params (slice off service from model name)
        completion_params = self.completion_params(
            config=config,
            tools=len(tools) > 0,
        )

        # tools/tool_choice are omitted entirely (NOT_GIVEN) when no tools are supplied
        request = dict(
            messages=openai_chat_messages(input),
            tools=openai_chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
            tool_choice=openai_chat_tool_choice(tool_choice) if len(tools) > 0 else NOT_GIVEN,
            **completion_params,
        )

        # move any params unsupported by the client signature into extra_body
        self.validate_request_params(request)

        try:
            # generate completion and save response for model call
            request_start = perf_counter()
            completion = retry_call(
                self.client.chat.completions.create,
                retries=config.retries,
                sleep_interval=config.retry_interval,
                **request
            )
            # handle streaming response: anything that is not a ChatCompletion
            # is treated as a chunk iterator and collected into one completion
            ttft = None
            is_stream_response = not isinstance(completion, ChatCompletion)
            if is_stream_response:
                collected_chunks = []
                for chunk in completion:
                    collected_chunks.append(chunk)
                    # TTFT should reflect first generated token/content chunk, not just any chunk.
                    # Different OpenAI-compatible servers may return delta as object or dict.
                    if ttft is None and self._chunk_has_generation_payload(chunk):
                        ttft = perf_counter() - request_start
                completion = collect_stream_response(collected_chunks)

            response = completion.model_dump()
            self.on_response(response)

            # return output and call
            choices = self.chat_choices_from_completion(completion, tools)
            model_output = model_output_from_openai(completion, choices)
            if ttft is not None:
                model_output.metadata = model_output.metadata or {}
                model_output.metadata['ttft'] = ttft
                model_output.metadata['ttft_source'] = 'first_content_stream_chunk'
            return model_output

        except (BadRequestError, UnprocessableEntityError, PermissionDeniedError) as ex:
            # 4xx-style errors are mapped to a ModelOutput/Exception by the hook
            return self.handle_bad_request(ex)
        except ValueError as ex:
            logger.error(f'Model [{self.model_name}] returned an invalid response: {ex}')
            raise

    def resolve_tools(self, tools: List[ToolInfo], tool_choice: ToolChoice,
                      config: GenerateConfig) -> Tuple[List[ToolInfo], ToolChoice, GenerateConfig]:
        """Provides an opportunity for concrete classes to customize tool resolution."""
        return tools, tool_choice, config

    def completion_params(self, config: GenerateConfig, tools: bool) -> Dict[str, Any]:
        """Build the keyword arguments for ``chat.completions.create``."""
        return openai_completion_params(
            model=self.model_name,
            config=config,
            tools=tools,
        )

    def validate_request_params(self, params: Dict[str, Any]):
        """Hook for subclasses to do custom request parameter validation.

        Default behavior: any key not accepted by the client's
        ``chat.completions.create`` signature is moved into ``extra_body``
        so OpenAI-compatible servers still receive it. Mutates ``params``
        in place.
        """
        # Cache supported params to avoid repeated calls to inspect.signature.
        if not hasattr(self, '_valid_params'):
            self._valid_params = get_supported_params(self.client.chat.completions.create)

        # Move unsupported parameters to extra_body.
        extra_body = params.get('extra_body', {})
        for key in list(params.keys()):
            if key not in self._valid_params:
                extra_body[key] = params.pop(key)

        if extra_body:
            params['extra_body'] = extra_body

    def on_response(self, response: Dict[str, Any]) -> None:
        """Hook for subclasses to do custom response handling."""
        pass

    def chat_choices_from_completion(self, completion: ChatCompletion,
                                     tools: List[ToolInfo]) -> List[ChatCompletionChoice]:
        """Hook for subclasses to do custom chat choice processing."""
        return chat_choices_from_openai(completion, tools)

    def handle_bad_request(self, ex: APIStatusError) -> Union[ModelOutput, Exception]:
        """Hook for subclasses to do bad request handling."""
        return openai_handle_bad_request(self.model_name, ex)

    @staticmethod
    def _chunk_has_generation_payload(chunk: Any) -> bool:
        """Return True when stream chunk carries actual generated payload.

        A chunk "counts" when any choice's delta has non-empty content,
        reasoning text, or tool calls; delta may be an object or a dict
        depending on the server implementation.
        """
        choices = getattr(chunk, 'choices', None) or []
        for choice in choices:
            delta = getattr(choice, 'delta', None)
            if delta is None:
                continue

            if isinstance(delta, dict):
                content = delta.get('content')
                reasoning = delta.get('reasoning_content') or delta.get('reasoning')
                tool_calls = delta.get('tool_calls')
            else:
                content = getattr(delta, 'content', None)
                reasoning = getattr(delta, 'reasoning_content', None) or getattr(delta, 'reasoning', None)
                tool_calls = getattr(delta, 'tool_calls', None)

            # empty string / empty list deltas (e.g. role-only chunks) do not count
            if content not in (None, '', []):
                return True
            if reasoning not in (None, '', []):
                return True
            if tool_calls:
                return True
        return False