"""Conversion between OpenAI APIs and native SRT APIs"""

import asyncio
import json
import os
from http import HTTPStatus

from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse

from sglang.srt.conversation import (
    Conversation,
    SeparatorStyle,
    chat_template_exists,
    generate_chat_conv,
    register_conv_template,
)
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.openai_api.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    DeltaMessage,
    ErrorResponse,
    LogProbs,
    UsageInfo,
)

chat_template_name = None
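# Set by load_chat_template_for_openai_api(); while it stays None,
# v1_chat_completions falls back to the tokenizer's built-in chat template.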


def create_error_response(
    message: str,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
):
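    """Wrap an OpenAI-style ErrorResponse in a JSONResponse."""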
    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
    return JSONResponse(content=error.model_dump(), status_code=error.code)


def create_streaming_error_response(
    message: str,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> str:
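    """Serialize an OpenAI-style ErrorResponse as a JSON string for SSE streams."""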
    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
    json_str = json.dumps({"error": error.model_dump()})
    return json_str


def load_chat_template_for_openai_api(chat_template_arg):
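    """Register a chat template for the OpenAI-compatible chat endpoint.

    `chat_template_arg` is either a built-in template name or a path to a JSON
    file describing a custom template. A sketch of the expected file layout
    (the keys are the ones read below; the values are illustrative placeholders):

        {
            "name": "my_template",
            "system": "SYSTEM:",
            "system_message": "You are a helpful assistant.",
            "user": "USER:",
            "assistant": "ASSISTANT:",
            "sep_style": "<SeparatorStyle member name>",
            "stop_str": "USER:"
        }

    "system_message" and "sep" are optional; the other keys are required.
    """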
    global chat_template_name

    print(f"Use chat template: {chat_template_arg}")
    if not chat_template_exists(chat_template_arg):
        if not os.path.exists(chat_template_arg):
            raise RuntimeError(
                f"Chat template {chat_template_arg} is not a built-in template name "
                "or a valid chat template file path."
            )
        with open(chat_template_arg, "r") as filep:
            template = json.load(filep)
            try:
                sep_style = SeparatorStyle[template["sep_style"]]
            except KeyError:
                raise ValueError(
                    f"Unknown separator style: {template['sep_style']}"
                ) from None
            register_conv_template(
                Conversation(
                    name=template["name"],
                    system_template=template["system"] + "\n{system_message}",
                    system_message=template.get("system_message", ""),
                    roles=(template["user"], template["assistant"]),
                    sep_style=sep_style,
                    sep=template.get("sep", "\n"),
                    stop_str=template["stop_str"],
                ),
                override=True,
            )
        chat_template_name = template["name"]
    else:
        chat_template_name = chat_template_arg


async def v1_completions(tokenizer_manager, raw_request: Request):
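    """Adapt an OpenAI /v1/completions request to a native GenerateReqInput.

    Handles both streaming (server-sent events terminated by "data: [DONE]")
    and non-streaming responses, optional prompt echo, and OpenAI-style
    logprobs.
    """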
    request_json = await raw_request.json()
    request = CompletionRequest(**request_json)

    adapted_request = GenerateReqInput(
        text=request.prompt,
        sampling_params={
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "stop": request.stop,
            "top_p": request.top_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "regex": request.regex,
            "n": request.n,
            "ignore_eos": request.ignore_eos,
        },
        return_logprob=request.logprobs is not None and request.logprobs > 0,
        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
        return_text_in_logprobs=True,
        stream=request.stream,
    )

    if adapted_request.stream:

        async def generate_stream_resp():
            stream_buffer = ""
            n_prev_token = 0
            try:
                async for content in tokenizer_manager.generate_request(
                    adapted_request, raw_request
                ):
                    text = content["text"]
                    prompt_tokens = content["meta_info"]["prompt_tokens"]
                    completion_tokens = content["meta_info"]["completion_tokens"]

                    if not stream_buffer:  # The first chunk
                        if request.echo:
                            # Prepend the prompt to the response text.
                            text = request.prompt + text

                    if request.logprobs:
                        # Prefill logprobs are included only in the first chunk, and only when echo is enabled.
                        if not stream_buffer and request.echo:
                            prefill_token_logprobs = content["meta_info"][
                                "prefill_token_logprobs"
                            ]
                            prefill_top_logprobs = content["meta_info"][
                                "prefill_top_logprobs"
                            ]
                        else:
                            prefill_token_logprobs = None
                            prefill_top_logprobs = None

                        logprobs = to_openai_style_logprobs(
                            prefill_token_logprobs=prefill_token_logprobs,
                            prefill_top_logprobs=prefill_top_logprobs,
                            decode_token_logprobs=content["meta_info"][
                                "decode_token_logprobs"
                            ][n_prev_token:],
                            decode_top_logprobs=content["meta_info"][
                                "decode_top_logprobs"
                            ][n_prev_token:],
                        )

                        n_prev_token = len(
                            content["meta_info"]["decode_token_logprobs"]
                        )
                    else:
                        logprobs = None

                    delta = text[len(stream_buffer) :]
                    stream_buffer = stream_buffer + delta
                    choice_data = CompletionResponseStreamChoice(
                        index=0,
                        text=delta,
                        logprobs=logprobs,
                        finish_reason=content["meta_info"]["finish_reason"],
                    )
                    chunk = CompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        object="text_completion",
                        choices=[choice_data],
                        model=request.model,
                        usage=UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        ),
                    )
                    yield f"data: {chunk.model_dump_json()}\n\n"
            except ValueError as e:
                error = create_streaming_error_response(str(e))
                yield f"data: {error}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream_resp(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(adapted_request),
        )

    # Non-streaming response.
    try:
        ret = await tokenizer_manager.generate_request(
            adapted_request, raw_request
        ).__anext__()
    except ValueError as e:
        return create_error_response(str(e))

    if not isinstance(ret, list):
        ret = [ret]
    choices = []

    for idx, ret_item in enumerate(ret):
        text = ret_item["text"]

        if request.echo:
            text = request.prompt + text

        if request.logprobs:
            if request.echo:
                prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"]
                prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"]
            else:
                prefill_token_logprobs = None
                prefill_top_logprobs = None

            logprobs = to_openai_style_logprobs(
                prefill_token_logprobs=prefill_token_logprobs,
                prefill_top_logprobs=prefill_top_logprobs,
                decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"],
                decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"],
            )
        else:
            logprobs = None

        choice_data = CompletionResponseChoice(
            index=idx,
            text=text,
            logprobs=logprobs,
            finish_reason=ret_item["meta_info"]["finish_reason"],
        )

        choices.append(choice_data)

    response = CompletionResponse(
        id=ret[0]["meta_info"]["id"],
        model=request.model,
        choices=choices,
        usage=UsageInfo(
            prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
            completion_tokens=sum(
                item["meta_info"]["completion_tokens"] for item in ret
            ),
            total_tokens=ret[0]["meta_info"]["prompt_tokens"]
            + sum(item["meta_info"]["completion_tokens"] for item in ret),
        ),
    )

    return response


async def v1_chat_completions(tokenizer_manager, raw_request: Request):
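    """Adapt an OpenAI /v1/chat/completions request to a GenerateReqInput.

    The messages are rendered into a single prompt either with the tokenizer's
    own chat template (when none was registered here) or with the template set
    by load_chat_template_for_openai_api, then streamed or returned whole.
    """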
    request_json = await raw_request.json()
    request = ChatCompletionRequest(**request_json)

    # Prep the data needed for the underlying GenerateReqInput:
    #  - prompt: The full prompt string.
    #  - stop: Custom stop tokens.
    #  - image_data: None or a list of image strings (URLs or base64 strings).
    #    None skips any image processing in GenerateReqInput.
    if not isinstance(request.messages, str):
        # Apply chat template and its stop strings.
        if chat_template_name is None:
            prompt = tokenizer_manager.tokenizer.apply_chat_template(
                request.messages, tokenize=False, add_generation_prompt=True
            )
            stop = request.stop
            image_data = None
        else:
            conv = generate_chat_conv(request, chat_template_name)
            prompt = conv.get_prompt()
            image_data = conv.image_data
            stop = conv.stop_str or []
            if request.stop:
                if isinstance(request.stop, str):
                    stop.append(request.stop)
                else:
                    stop.extend(request.stop)
    else:
        # Use the raw prompt and stop strings when `messages` is already a plain string.
        prompt = request.messages
        stop = request.stop
        image_data = None

    adapted_request = GenerateReqInput(
        text=prompt,
        image_data=image_data,
        sampling_params={
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "stop": stop,
            "top_p": request.top_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "regex": request.regex,
            "n": request.n,
        },
        stream=request.stream,
    )

    if adapted_request.stream:

        async def generate_stream_resp():
            is_first = True

            stream_buffer = ""
            try:
                async for content in tokenizer_manager.generate_request(
                    adapted_request, raw_request
                ):
                    if is_first:
                        # First chunk with role
                        is_first = False
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=0,
                            delta=DeltaMessage(role="assistant"),
                            finish_reason=content["meta_info"]["finish_reason"],
                        )
                        chunk = ChatCompletionStreamResponse(
                            id=content["meta_info"]["id"],
                            choices=[choice_data],
                            model=request.model,
                        )
                        yield f"data: {chunk.model_dump_json()}\n\n"

                    text = content["text"]
                    delta = text[len(stream_buffer) :]
                    stream_buffer = stream_buffer + delta
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(content=delta),
                        finish_reason=content["meta_info"]["finish_reason"],
                    )
                    chunk = ChatCompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        choices=[choice_data],
                        model=request.model,
                    )
                    yield f"data: {chunk.model_dump_json()}\n\n"
            except ValueError as e:
                error = create_streaming_error_response(str(e))
                yield f"data: {error}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream_resp(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(adapted_request),
        )

    # Non-streaming response.
    try:
        ret = await tokenizer_manager.generate_request(
            adapted_request, raw_request
        ).__anext__()
    except ValueError as e:
        return create_error_response(str(e))

    if not isinstance(ret, list):
        ret = [ret]
    choices = []
    total_prompt_tokens = 0
    total_completion_tokens = 0

    for idx, ret_item in enumerate(ret):
        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
        completion_tokens = ret_item["meta_info"]["completion_tokens"]

        choice_data = ChatCompletionResponseChoice(
            index=idx,
            message=ChatMessage(role="assistant", content=ret_item["text"]),
            finish_reason=ret_item["meta_info"]["finish_reason"],
        )

        choices.append(choice_data)
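        # All n choices share the same prompt, so prompt tokens are taken once
        # rather than summed across choices.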
        total_prompt_tokens = prompt_tokens
        total_completion_tokens += completion_tokens

    response = ChatCompletionResponse(
        id=ret[0]["meta_info"]["id"],
        model=request.model,
        choices=choices,
        usage=UsageInfo(
            prompt_tokens=total_prompt_tokens,
            completion_tokens=total_completion_tokens,
            total_tokens=total_prompt_tokens + total_completion_tokens,
        ),
    )

    return response


def to_openai_style_logprobs(
    prefill_token_logprobs=None,
    decode_token_logprobs=None,
    prefill_top_logprobs=None,
    decode_top_logprobs=None,
):
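    """Convert SRT logprob tuples into an OpenAI-style LogProbs object.

    Each token entry is a 3-tuple of (logprob, _, token_text); prefill
    (prompt) entries are passed only when echo is requested. Text offsets are
    not tracked yet, so -1 is used as a placeholder.
    """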
    ret_logprobs = LogProbs()

    def append_token_logprobs(token_logprobs):
        for logprob, _, token_text in token_logprobs:
            ret_logprobs.tokens.append(token_text)
            ret_logprobs.token_logprobs.append(logprob)

            # Not supported yet
            ret_logprobs.text_offset.append(-1)

    def append_top_logprobs(top_logprobs):
        for tokens in top_logprobs:
            if tokens is not None:
                ret_logprobs.top_logprobs.append(
                    {token[2]: token[0] for token in tokens}
                )
            else:
                ret_logprobs.top_logprobs.append(None)

    if prefill_token_logprobs is not None:
        append_token_logprobs(prefill_token_logprobs)
    if decode_token_logprobs is not None:
        append_token_logprobs(decode_token_logprobs)
    if prefill_top_logprobs is not None:
        append_top_logprobs(prefill_top_logprobs)
    if decode_top_logprobs is not None:
        append_top_logprobs(decode_top_logprobs)

    return ret_logprobs