# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conversion between OpenAI APIs and native SRT APIs"""

import asyncio
import json
import logging
import os
import time
import uuid
from http import HTTPStatus
from typing import Dict, List

from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import ORJSONResponse, StreamingResponse
from pydantic import ValidationError

try:
    from outlines.fsm.json_schema import convert_json_schema_to_str
except ImportError:
    # Before outlines 0.0.47, convert_json_schema_to_str is under
    # outlines.integrations.utils
    from outlines.integrations.utils import convert_json_schema_to_str

from sglang.srt.code_completion_parser import (
    generate_completion_prompt_from_request,
    is_completion_template_defined,
)
from sglang.srt.conversation import (
    Conversation,
    SeparatorStyle,
    chat_template_exists,
    generate_chat_conv,
    generate_embedding_convs,
    register_conv_template,
)
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.openai_api.protocol import (
    BatchRequest,
    BatchResponse,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatCompletionTokenLogprob,
    ChatMessage,
    ChoiceLogprobs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    DeltaMessage,
    EmbeddingObject,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
    FileDeleteResponse,
    FileRequest,
    FileResponse,
    FunctionResponse,
    LogProbs,
    MultimodalEmbeddingInput,
    ToolCall,
    TopLogprob,
    UsageInfo,
)
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

chat_template_name = None


class FileMetadata:
    def __init__(self, filename: str, purpose: str):
        self.filename = filename
        self.purpose = purpose


# In-memory storage for batch jobs and files
batch_storage: Dict[str, BatchResponse] = {}
file_id_request: Dict[str, FileMetadata] = {}
file_id_response: Dict[str, FileResponse] = {}
# map file id to file path in SGLang backend
file_id_storage: Dict[str, str] = {}

# backend storage directory
storage_dir = None
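# Lifecycle of the maps above: v1_files_create saves an uploaded .jsonl file under
# storage_dir and records it in file_id_request / file_id_storage / file_id_response;
# v1_batches registers a BatchResponse in batch_storage, and process_batch later
# attaches the generated output file to the same maps.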


def create_error_response(
    message: str,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
):
    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
    return ORJSONResponse(content=error.model_dump(), status_code=error.code)


def create_streaming_error_response(
    message: str,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> str:
    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
    json_str = json.dumps({"error": error.model_dump()})
    return json_str


def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, model_path):
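    """Load the chat template used by the OpenAI-compatible API server.

    `chat_template_arg` may be a built-in template name, a Jinja file (applied
    directly to the tokenizer), or a JSON conversation template that gets
    registered via `register_conv_template`.
    """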
    global chat_template_name

    logger.info(
        f"Use chat template for the OpenAI-compatible API server: {chat_template_arg}"
    )

    if not chat_template_exists(chat_template_arg):
        if not os.path.exists(chat_template_arg):
            raise RuntimeError(
                f"Chat template {chat_template_arg} is not a built-in template name "
                "or a valid chat template file path."
            )
        if chat_template_arg.endswith(".jinja"):
            with open(chat_template_arg, "r") as f:
                chat_template = "".join(f.readlines()).strip("\n")
            tokenizer_manager.tokenizer.chat_template = chat_template.replace(
                "\\n", "\n"
            )
            chat_template_name = None
        else:
            assert chat_template_arg.endswith(
                ".json"
            ), "unrecognized format of chat template file"
            with open(chat_template_arg, "r") as filep:
                template = json.load(filep)
                try:
                    sep_style = SeparatorStyle[template["sep_style"]]
                except KeyError:
                    raise ValueError(
                        f"Unknown separator style: {template['sep_style']}"
                    ) from None
                register_conv_template(
                    Conversation(
                        name=template["name"],
                        system_template=template["system"] + "\n{system_message}",
                        system_message=template.get("system_message", ""),
                        roles=(template["user"], template["assistant"]),
                        sep_style=sep_style,
                        sep=template.get("sep", "\n"),
                        stop_str=template["stop_str"],
                    ),
                    override=True,
                )
            chat_template_name = template["name"]
    else:
        chat_template_name = chat_template_arg

    # Check chat-template
    # TODO:
    # 1. Do not import any code from sglang.lang
    # 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.
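    # Illustrative sketch (an assumed example, not shipped with SGLang) of a .json
    # template accepted by the branch above; `sep_style` must name a member of
    # `SeparatorStyle`:
    # {
    #     "name": "my_model",
    #     "system": "<|system|>",
    #     "system_message": "You are a helpful assistant.",
    #     "user": "<|user|>",
    #     "assistant": "<|assistant|>",
    #     "sep_style": "ADD_COLON_SINGLE",
    #     "sep": "\n",
    #     "stop_str": ["<|user|>"]
    # }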


async def v1_files_create(
    file: UploadFile, purpose: str, file_storage_path: str = None
):
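    """Handle an OpenAI-style file upload: persist the content under `storage_dir`
    and register it in the in-memory file maps, returning a FileResponse."""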
    try:
        global storage_dir
        if file_storage_path:
            storage_dir = file_storage_path
        # Read the file content
        file_content = await file.read()

        # Create an instance of RequestBody
        request_body = FileRequest(file=file_content, purpose=purpose)

        # Save the file to the sglang_oai_storage directory
        os.makedirs(storage_dir, exist_ok=True)
        file_id = f"backend_input_file-{uuid.uuid4()}"
        filename = f"{file_id}.jsonl"
        file_path = os.path.join(storage_dir, filename)

        with open(file_path, "wb") as f:
            f.write(request_body.file)

        # add info to global file map
        file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
        file_id_storage[file_id] = file_path

        # Return the response in the required format
        response = FileResponse(
            id=file_id,
            bytes=len(request_body.file),
            created_at=int(time.time()),
            filename=file.filename,
            purpose=request_body.purpose,
        )
        file_id_response[file_id] = response

        return response
    except ValidationError as e:
        return {"error": "Invalid input", "details": e.errors()}


async def v1_delete_file(file_id: str):
    # Retrieve the file job from the in-memory storage
    file_response = file_id_response.get(file_id)
    if file_response is None:
        raise HTTPException(status_code=404, detail="File not found")
    file_path = file_id_storage.get(file_id)
    if file_path is None:
        raise HTTPException(status_code=404, detail="File not found")
    os.remove(file_path)
    del file_id_response[file_id]
    del file_id_storage[file_id]
    return FileDeleteResponse(id=file_id, deleted=True)


async def v1_batches(tokenizer_manager, raw_request: Request):
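    """Create a batch job from a previously uploaded input file and start
    processing it asynchronously via `process_batch`."""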
    try:
        body = await raw_request.json()

        batch_request = BatchRequest(**body)

        batch_id = f"batch_{uuid.uuid4()}"

        # Create an instance of BatchResponse
        batch_response = BatchResponse(
            id=batch_id,
            endpoint=batch_request.endpoint,
            input_file_id=batch_request.input_file_id,
            completion_window=batch_request.completion_window,
            created_at=int(time.time()),
            metadata=batch_request.metadata,
        )

        batch_storage[batch_id] = batch_response

        # Start processing the batch asynchronously
        asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))

        # Return the initial batch_response
        return batch_response

    except ValidationError as e:
        return {"error": "Invalid input", "details": e.errors()}
    except Exception as e:
        return {"error": str(e)}


async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
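    """Execute every request of a batch: parse the input .jsonl file, dispatch the
    adapted requests to the tokenizer manager, and write per-request results to an
    output file registered in the in-memory storage."""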
    try:
        # Update the batch status to "in_progress"
        batch_storage[batch_id].status = "in_progress"
        batch_storage[batch_id].in_progress_at = int(time.time())

        # Retrieve the input file content
        input_file_request = file_id_request.get(batch_request.input_file_id)
        if not input_file_request:
            raise ValueError("Input file not found")

        # Parse the JSONL file and process each request
        input_file_path = file_id_storage.get(batch_request.input_file_id)
        with open(input_file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        total_requests = len(lines)
        completed_requests = 0
        failed_requests = 0

        all_ret = []
        end_point = batch_storage[batch_id].endpoint
        file_request_list = []
        all_requests = []
        request_ids = []
        for line_id, line in enumerate(lines):
            request_data = json.loads(line)
            file_request_list.append(request_data)
            body = request_data["body"]
            request_ids.append(f"{batch_id}-req_{line_id}")

            # Although streaming is supported for standalone completions, it is not supported in
            # batch mode (multiple completions in a single request).
            if body.get("stream", False):
                raise ValueError("Streaming requests are not supported in batch mode")

            if end_point == "/v1/chat/completions":
                all_requests.append(ChatCompletionRequest(**body))
            elif end_point == "/v1/completions":
                all_requests.append(CompletionRequest(**body))

        if end_point == "/v1/chat/completions":
            adapted_request, request = v1_chat_generate_request(
                all_requests, tokenizer_manager, request_ids=request_ids
            )
        elif end_point == "/v1/completions":
            adapted_request, request = v1_generate_request(
                all_requests, request_ids=request_ids
            )

        try:
            created = int(time.time())
            ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
            if not isinstance(ret, list):
                ret = [ret]
            if end_point == "/v1/chat/completions":
                responses = v1_chat_generate_response(
                    request,
                    ret,
                    created,
                    to_file=True,
                    cache_report=tokenizer_manager.server_args.enable_cache_report,
                    tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
                )
            else:
                responses = v1_generate_response(
                    request,
                    ret,
                    tokenizer_manager,
                    created,
                    to_file=True,
                    cache_report=tokenizer_manager.server_args.enable_cache_report,
                )

        except Exception as e:
            logger.error(f"error: {get_exception_traceback()}")
            responses = []
            error_json = {
                "id": f"batch_req_{uuid.uuid4()}",
                "custom_id": request_data.get("custom_id"),
                "response": None,
                "error": {"message": str(e)},
            }
            all_ret.append(error_json)
            failed_requests += len(file_request_list)

        for idx, response in enumerate(responses):
            # The batch_req id here could instead be named at batch granularity.
            response_json = {
                "id": f"batch_req_{uuid.uuid4()}",
                "custom_id": file_request_list[idx].get("custom_id"),
                "response": response,
                "error": None,
            }
            all_ret.append(response_json)
            completed_requests += 1

        # Write results to a new file
        output_file_id = f"backend_result_file-{uuid.uuid4()}"
        global storage_dir
        output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
        with open(output_file_path, "w", encoding="utf-8") as f:
            for ret in all_ret:
                f.write(json.dumps(ret) + "\n")

        # Update batch response with output file information
        retrieve_batch = batch_storage[batch_id]
        retrieve_batch.output_file_id = output_file_id
        file_id_storage[output_file_id] = output_file_path
        file_id_response[output_file_id] = FileResponse(
            id=output_file_id,
            bytes=os.path.getsize(output_file_path),
            created_at=int(time.time()),
            filename=f"{output_file_id}.jsonl",
            purpose="batch_result",
        )
        # Update batch status to "completed"
        retrieve_batch.status = "completed"
        retrieve_batch.completed_at = int(time.time())
        retrieve_batch.request_counts = {
            "total": total_requests,
            "completed": completed_requests,
            "failed": failed_requests,
        }

    except Exception as e:
        logger.error(f"error: {e}")
        # Update batch status to "failed"
        retrieve_batch = batch_storage[batch_id]
        retrieve_batch.status = "failed"
        retrieve_batch.failed_at = int(time.time())
        retrieve_batch.errors = {"message": str(e)}


async def v1_retrieve_batch(batch_id: str):
    # Retrieve the batch job from the in-memory storage
    batch_response = batch_storage.get(batch_id)
    if batch_response is None:
        raise HTTPException(status_code=404, detail="Batch not found")

    return batch_response


async def v1_cancel_batch(tokenizer_manager, batch_id: str):
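    """Cancel a batch that is still "validating" or "in_progress"; the per-request
    aborts happen asynchronously in `cancel_batch`."""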
    # Retrieve the batch job from the in-memory storage
    batch_response = batch_storage.get(batch_id)
    if batch_response is None:
        raise HTTPException(status_code=404, detail="Batch not found")

    # Only cancel when the status is "validating" or "in_progress"
    if batch_response.status in ["validating", "in_progress"]:
        # Start cancelling the batch asynchronously
        asyncio.create_task(
            cancel_batch(
                tokenizer_manager=tokenizer_manager,
                batch_id=batch_id,
                input_file_id=batch_response.input_file_id,
            )
        )

        # Update batch status to "cancelling"
        batch_response.status = "cancelling"

        return batch_response
    else:
        raise HTTPException(
            status_code=500,
            detail=f"Current status is {batch_response.status}, no need to cancel",
        )


async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
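    """Abort all in-flight requests of a batch (request ids follow the
    f"{batch_id}-req_{line_id}" convention used in `process_batch`) and mark the
    batch as cancelled."""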
    try:
        # Update the batch status to "cancelling"
        batch_storage[batch_id].status = "cancelling"

        # Retrieve the input file content
        input_file_request = file_id_request.get(input_file_id)
        if not input_file_request:
            raise ValueError("Input file not found")

        # Parse the JSONL file and process each request
        input_file_path = file_id_storage.get(input_file_id)
        with open(input_file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        # Cancel requests by request_ids
        for line_id in range(len(lines)):
            rid = f"{batch_id}-req_{line_id}"
            tokenizer_manager.abort_request(rid=rid)

        retrieve_batch = batch_storage[batch_id]
        retrieve_batch.status = "cancelled"

    except Exception as e:
        logger.error(f"error in SGLang: {e}")
        # Update batch status to "failed"
        retrieve_batch = batch_storage[batch_id]
        retrieve_batch.status = "failed"
        retrieve_batch.failed_at = int(time.time())
        retrieve_batch.errors = {"message": str(e)}


async def v1_retrieve_file(file_id: str):
    # Retrieve the batch job from the in-memory storage
    file_response = file_id_response.get(file_id)
    if file_response is None:
        raise HTTPException(status_code=404, detail="File not found")
    return file_response


async def v1_retrieve_file_content(file_id: str):
    file_pth = file_id_storage.get(file_id)
    if not file_pth or not os.path.exists(file_pth):
        raise HTTPException(status_code=404, detail="File not found")

    def iter_file():
        with open(file_pth, mode="rb") as file_like:
            yield from file_like

    return StreamingResponse(iter_file(), media_type="application/octet-stream")
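# Illustrative client-side flow for the file/batch endpoints above (a sketch only;
# the base_url/port and client wiring are assumptions, not part of this module):
#
#   client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
#   input_file = client.files.create(file=open("requests.jsonl", "rb"), purpose="batch")
#   batch = client.batches.create(
#       input_file_id=input_file.id,
#       endpoint="/v1/chat/completions",
#       completion_window="24h",
#   )
#   output_file_id = client.batches.retrieve(batch.id).output_file_id
#   results = client.files.content(output_file_id)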


def v1_generate_request(
    all_requests: List[CompletionRequest], request_ids: List[str] = None
):
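    """Convert OpenAI `/v1/completions` request(s) into a single GenerateReqInput.

    Returns `(adapted_request, request)`, where `request` is the lone
    CompletionRequest for a single request and the original list when several
    requests are batched (e.g. from a batch input file).
    """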
    if len(all_requests) > 1:
        first_prompt_type = type(all_requests[0].prompt)
        for request in all_requests:
            assert (
                type(request.prompt) is first_prompt_type
            ), "All prompts must be of the same type in file input settings"
            if request.n > 1:
                raise ValueError(
                    "Parallel sampling is not supported for completions from files"
                )

    prompts = []
    sampling_params_list = []
    return_logprobs = []
    logprob_start_lens = []
    top_logprobs_nums = []
    lora_paths = []

    for request in all_requests:
        # NOTE: with the OpenAI API, the prompt's logprobs are never computed
        if request.echo and request.logprobs:
            logger.warning(
                "Echo is not compatible with logprobs. "
                "To compute logprobs of the input prompt, please use the native /generate API."
            )

        prompt = request.prompt
        if is_completion_template_defined():
            prompt = generate_completion_prompt_from_request(request)
        prompts.append(prompt)

        lora_paths.append(request.lora_path)
        if request.echo and request.logprobs:
            current_logprob_start_len = 0
        else:
            current_logprob_start_len = -1
        sampling_params_list.append(
            {
                "temperature": request.temperature,
                "max_new_tokens": request.max_tokens,
                "min_new_tokens": request.min_tokens,
                "stop": request.stop,
                "stop_token_ids": request.stop_token_ids,
                "top_p": request.top_p,
                "top_k": request.top_k,
                "min_p": request.min_p,
                "presence_penalty": request.presence_penalty,
                "frequency_penalty": request.frequency_penalty,
                "repetition_penalty": request.repetition_penalty,
                "regex": request.regex,
                "json_schema": request.json_schema,
                "ebnf": request.ebnf,
                "n": request.n,
                "no_stop_trim": request.no_stop_trim,
                "ignore_eos": request.ignore_eos,
                "skip_special_tokens": request.skip_special_tokens,
            }
        )
        return_logprobs.append(request.logprobs is not None)
        logprob_start_lens.append(current_logprob_start_len)
        top_logprobs_nums.append(
            request.logprobs if request.logprobs is not None else 0
        )

    if len(all_requests) == 1:
        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
            prompt_kwargs = {"text": prompts[0]}
        else:
            prompt_kwargs = {"input_ids": prompts[0]}
        sampling_params_list = sampling_params_list[0]
        return_logprobs = return_logprobs[0]
        logprob_start_lens = logprob_start_lens[0]
        top_logprobs_nums = top_logprobs_nums[0]
        lora_paths = lora_paths[0]
    else:
        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
            prompt_kwargs = {"text": prompts}
        else:
            prompt_kwargs = {"input_ids": prompts}

    adapted_request = GenerateReqInput(
        **prompt_kwargs,
        sampling_params=sampling_params_list,
        return_logprob=return_logprobs,
        top_logprobs_num=top_logprobs_nums,
        logprob_start_len=logprob_start_lens,
        return_text_in_logprobs=True,
        stream=all_requests[0].stream,
        rid=request_ids,
        lora_path=lora_paths,
    )

    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


def v1_generate_response(
    request, ret, tokenizer_manager, created, to_file=False, cache_report=False
):
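    """Build an OpenAI-style completion response from backend results.

    With `to_file=True` (batch mode) this returns a list of JSON-serializable
    per-request dicts; otherwise it returns a CompletionResponse with aggregated
    usage information.
    """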
    choices = []
    echo = False

    if (not isinstance(request, list)) and request.echo:
        # TODO: handle the case where the prompt is token ids
        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
            # for the case of multiple str prompts
            prompts = request.prompt
        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
            # for the case of multiple token ids prompts
            prompts = [
                tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
                for prompt in request.prompt
            ]
        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
            # for the case of single token ids prompt
            prompts = [
                tokenizer_manager.tokenizer.decode(
                    request.prompt, skip_special_tokens=True
                )
            ]
        else:
            # for the case of single str prompt
            prompts = [request.prompt]
        echo = True

    for idx, ret_item in enumerate(ret):
        text = ret_item["text"]
        if isinstance(request, list) and request[idx].echo:
            echo = True
            text = request[idx].prompt + text
        if echo and not isinstance(request, list):
            prompt_index = idx // request.n
            text = prompts[prompt_index] + text

        logprobs = False
        if isinstance(request, list) and request[idx].logprobs is not None:
            logprobs = True
        elif (not isinstance(request, list)) and request.logprobs is not None:
            logprobs = True
        if logprobs:
            if echo:
                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
            else:
                input_token_logprobs = None
                input_top_logprobs = None

            logprobs = to_openai_style_logprobs(
                input_token_logprobs=input_token_logprobs,
                input_top_logprobs=input_top_logprobs,
                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
            )
        else:
            logprobs = None

        finish_reason = ret_item["meta_info"]["finish_reason"]

        if to_file:
            # to make the choice data json serializable
            choice_data = {
                "index": 0,
                "text": text,
                "logprobs": logprobs,
                "finish_reason": (finish_reason["type"] if finish_reason else ""),
                "matched_stop": (
                    finish_reason["matched"]
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
            }
        else:
            choice_data = CompletionResponseChoice(
                index=idx,
                text=text,
                logprobs=logprobs,
                finish_reason=(finish_reason["type"] if finish_reason else ""),
                matched_stop=(
                    finish_reason["matched"]
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
            )

        choices.append(choice_data)

    if to_file:
        responses = []
        for i, choice in enumerate(choices):
            response = {
                "status_code": 200,
                "request_id": ret[i]["meta_info"]["id"],
                "body": {
                    # id stays the same as the request id, but could be changed if needed
                    "id": ret[i]["meta_info"]["id"],
                    "object": "text_completion",
                    "created": created,
                    "model": request[i].model,
                    "choices": choice,
                    "usage": {
                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
                        + ret[i]["meta_info"]["completion_tokens"],
                    },
                    "system_fingerprint": None,
                },
            }
            responses.append(response)
        return responses
    else:
        prompt_tokens = sum(
            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
        )
        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
        cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
        response = CompletionResponse(
            id=ret[0]["meta_info"]["id"],
            model=request.model,
            created=created,
            choices=choices,
            usage=UsageInfo(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
                prompt_tokens_details=(
                    {"cached_tokens": cached_tokens} if cache_report else None
                ),
            ),
        )
    return response


async def v1_completions(tokenizer_manager, raw_request: Request):
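    """Handle `/v1/completions`: stream SSE chunks when `stream` is set, otherwise
    return a single CompletionResponse."""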
    request_json = await raw_request.json()
    all_requests = [CompletionRequest(**request_json)]
    created = int(time.time())
    adapted_request, request = v1_generate_request(all_requests)

    if adapted_request.stream:

        async def generate_stream_resp():
            stream_buffers = {}
            n_prev_tokens = {}
            prompt_tokens = {}
            completion_tokens = {}
            cached_tokens = {}

            try:
                async for content in tokenizer_manager.generate_request(
                    adapted_request, raw_request
                ):
                    index = content.get("index", 0)

                    stream_buffer = stream_buffers.get(index, "")
                    n_prev_token = n_prev_tokens.get(index, 0)

                    text = content["text"]
                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                    cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)

                    if not stream_buffer:  # The first chunk
                        if request.echo:
                            if isinstance(request.prompt, str):
                                # for the case of single str prompts
                                prompts = request.prompt
                            elif isinstance(request.prompt, list):
                                if isinstance(request.prompt[0], str):
                                    # for the case of multiple str prompts
                                    prompts = request.prompt[index // request.n]
                                elif isinstance(request.prompt[0], int):
                                    # for the case of single token ids prompt
                                    prompts = tokenizer_manager.tokenizer.decode(
                                        request.prompt, skip_special_tokens=True
                                    )
                                elif isinstance(request.prompt[0], list) and isinstance(
                                    request.prompt[0][0], int
                                ):
                                    # for the case of multiple token ids prompts
                                    prompts = tokenizer_manager.tokenizer.decode(
                                        request.prompt[index // request.n],
                                        skip_special_tokens=True,
                                    )

                            # Prepend prompt in response text.
                            text = prompts + text

                    if request.logprobs is not None:
                        # The first chunk and echo is enabled.
                        if not stream_buffer and request.echo:
                            input_token_logprobs = content["meta_info"][
                                "input_token_logprobs"
                            ]
                            input_top_logprobs = content["meta_info"][
                                "input_top_logprobs"
                            ]
                        else:
                            input_token_logprobs = None
                            input_top_logprobs = None

                        logprobs = to_openai_style_logprobs(
                            input_token_logprobs=input_token_logprobs,
                            input_top_logprobs=input_top_logprobs,
                            output_token_logprobs=content["meta_info"][
                                "output_token_logprobs"
                            ][n_prev_token:],
                            output_top_logprobs=content["meta_info"][
                                "output_top_logprobs"
                            ][n_prev_token:],
                        )
                        n_prev_token = len(
                            content["meta_info"]["output_token_logprobs"]
                        )
                    else:
                        logprobs = None

                    delta = text[len(stream_buffer) :]
                    stream_buffer = stream_buffer + delta
                    finish_reason = content["meta_info"]["finish_reason"]
                    choice_data = CompletionResponseStreamChoice(
                        index=index,
                        text=delta,
                        logprobs=logprobs,
                        finish_reason=(finish_reason["type"] if finish_reason else ""),
                        matched_stop=(
                            finish_reason["matched"]
                            if finish_reason and "matched" in finish_reason
                            else None
                        ),
                    )
                    chunk = CompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        created=created,
                        object="text_completion",
                        choices=[choice_data],
                        model=request.model,
                    )

                    stream_buffers[index] = stream_buffer
                    n_prev_tokens[index] = n_prev_token

                    yield f"data: {chunk.model_dump_json()}\n\n"
                if request.stream_options and request.stream_options.include_usage:
                    total_prompt_tokens = sum(
                        tokens
                        for i, tokens in prompt_tokens.items()
                        if i % request.n == 0
                    )
                    total_completion_tokens = sum(
                        tokens for tokens in completion_tokens.values()
                    )
                    cache_report = tokenizer_manager.server_args.enable_cache_report
                    if cache_report:
                        cached_tokens_sum = sum(
                            tokens for tokens in cached_tokens.values()
                        )
                        prompt_tokens_details = {"cached_tokens": cached_tokens_sum}
                    else:
                        prompt_tokens_details = None
                    usage = UsageInfo(
                        prompt_tokens=total_prompt_tokens,
                        completion_tokens=total_completion_tokens,
                        total_tokens=total_prompt_tokens + total_completion_tokens,
                        prompt_tokens_details=prompt_tokens_details,
                    )

                    final_usage_chunk = CompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        created=created,
                        choices=[],
                        model=request.model,
                        usage=usage,
                    )
                    final_usage_data = final_usage_chunk.model_dump_json(
                        exclude_none=True
                    )
                    yield f"data: {final_usage_data}\n\n"
            except ValueError as e:
                error = create_streaming_error_response(str(e))
                yield f"data: {error}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream_resp(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(adapted_request),
        )

    # Non-streaming response.
    try:
        ret = await tokenizer_manager.generate_request(
            adapted_request, raw_request
        ).__anext__()
    except ValueError as e:
        return create_error_response(str(e))

    if not isinstance(ret, list):
        ret = [ret]

    response = v1_generate_response(
        request,
        ret,
        tokenizer_manager,
        created,
        cache_report=tokenizer_manager.server_args.enable_cache_report,
    )
    return response


def v1_chat_generate_request(
    all_requests: List[ChatCompletionRequest],
    tokenizer_manager,
    request_ids: List[str] = None,
):
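    """Convert OpenAI `/v1/chat/completions` request(s) into a single
    GenerateReqInput, applying either the tokenizer's chat template or a registered
    conversation template. Follows the same single/list return convention as
    `v1_generate_request`.
    """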
    input_ids = []
    sampling_params_list = []
    image_data_list = []
    return_logprobs = []
    logprob_start_lens = []
    top_logprobs_nums = []
    modalities_list = []
    lora_paths = []

    # NOTE: with the OpenAI API, the prompt's logprobs are never computed

    for request in all_requests:
        # Prep the data needed for the underlying GenerateReqInput:
        #  - prompt: The full prompt string.
        #  - stop: Custom stop tokens.
        #  - image_data: None or a list of image strings (URLs or base64 strings).
        #    None skips any image processing in GenerateReqInput.
        if not isinstance(request.messages, str):
            # Apply chat template and its stop strings.
            tools = None
            if request.tools and request.tool_choice != "none":
                request.skip_special_tokens = False
                if not isinstance(request.tool_choice, str):
                    tools = [
                        item.function.model_dump()
                        for item in request.tools
                        if item.function.name == request.tool_choice.function.name
                    ]
                else:
                    tools = [item.function.model_dump() for item in request.tools]

            if chat_template_name is None:
                openai_compatible_messages = []
                for message in request.messages:
                    if isinstance(message.content, str):
                        openai_compatible_messages.append(
                            {"role": message.role, "content": message.content}
                        )
                    else:
                        content_list = message.dict()["content"]
                        for content in content_list:
                            if content["type"] == "text":
                                openai_compatible_messages.append(
                                    {"role": message.role, "content": content["text"]}
                                )
                if openai_compatible_messages[-1]["role"] == "assistant":
                    assistant_prefix = openai_compatible_messages[-1]["content"]
                    openai_compatible_messages = openai_compatible_messages[:-1]
                else:
                    assistant_prefix = None

                try:
                    prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
                        openai_compatible_messages,
                        tokenize=True,
                        add_generation_prompt=True,
                        tools=tools,
                    )
                except Exception:
                    # This except branch will be triggered when the chosen model
                    # has a tools input format that is not compatible with
                    # OpenAI's apply_chat_template tool_call format, like Mistral.
                    tools = [t if "function" in t else {"function": t} for t in tools]
                    prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
                        openai_compatible_messages,
                        tokenize=True,
                        add_generation_prompt=True,
                        tools=tools,
                    )

                if assistant_prefix:
                    encoded = tokenizer_manager.tokenizer.encode(assistant_prefix)
                    if (
                        encoded
                        and encoded[0] == tokenizer_manager.tokenizer.bos_token_id
                    ):
                        encoded = encoded[1:]
                    prompt_ids += encoded
                stop = request.stop
                image_data = None
                modalities = []
            else:
                conv = generate_chat_conv(request, chat_template_name)
                prompt = conv.get_prompt()
                image_data = conv.image_data
                modalities = conv.modalities
                stop = conv.stop_str or []
                if request.stop:
                    if isinstance(request.stop, str):
                        stop.append(request.stop)
                    else:
                        stop.extend(request.stop)
                prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
        else:
            # Use the raw prompt and stop strings if the messages field is already a string.
            prompt_ids = request.messages
            stop = request.stop
            image_data = None
            modalities = []
        input_ids.append(prompt_ids)
        return_logprobs.append(request.logprobs)
        logprob_start_lens.append(-1)
        top_logprobs_nums.append(request.top_logprobs or 0)
        lora_paths.append(request.lora_path)

        sampling_params = {
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "min_new_tokens": request.min_tokens,
            "stop": stop,
            "stop_token_ids": request.stop_token_ids,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "min_p": request.min_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "repetition_penalty": request.repetition_penalty,
            "regex": request.regex,
            "ebnf": request.ebnf,
            "n": request.n,
            "no_stop_trim": request.no_stop_trim,
            "ignore_eos": request.ignore_eos,
            "skip_special_tokens": request.skip_special_tokens,
        }

        if request.response_format and request.response_format.type == "json_schema":
            sampling_params["json_schema"] = convert_json_schema_to_str(
                request.response_format.json_schema.schema_
            )
        elif (
            request.response_format and request.response_format.type == "structural_tag"
        ):
            sampling_params["structural_tag"] = convert_json_schema_to_str(
                request.response_format.model_dump(by_alias=True)
            )
        sampling_params_list.append(sampling_params)

        image_data_list.append(image_data)
        modalities_list.append(modalities)
    if len(all_requests) == 1:
        if isinstance(input_ids[0], str):
            prompt_kwargs = {"text": input_ids[0]}
        else:
            prompt_kwargs = {"input_ids": input_ids[0]}
        sampling_params_list = sampling_params_list[0]
        image_data_list = image_data_list[0]
        return_logprobs = return_logprobs[0]
        logprob_start_lens = logprob_start_lens[0]
        top_logprobs_nums = top_logprobs_nums[0]
        modalities_list = modalities_list[0]
        lora_paths = lora_paths[0]
    else:
        if isinstance(input_ids[0], str):
            prompt_kwargs = {"text": input_ids}
        else:
            prompt_kwargs = {"input_ids": input_ids}

    adapted_request = GenerateReqInput(
        **prompt_kwargs,
        image_data=image_data_list,
        sampling_params=sampling_params_list,
        return_logprob=return_logprobs,
        logprob_start_len=logprob_start_lens,
        top_logprobs_num=top_logprobs_nums,
        stream=all_requests[0].stream,
        return_text_in_logprobs=True,
        rid=request_ids,
        modalities=modalities_list,
        lora_path=lora_paths,
    )

    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


def v1_chat_generate_response(
    request,
    ret,
    created,
    to_file=False,
    cache_report=False,
    tool_call_parser=None,
    reasoning_parser=None,
):
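    """Build an OpenAI-style chat completion response, including optional logprobs,
    parsed tool calls, and separated reasoning content."""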
    choices = []

    for idx, ret_item in enumerate(ret):
        logprobs = False
        if isinstance(request, list) and request[idx].logprobs:
            logprobs = True
        elif (not isinstance(request, list)) and request.logprobs:
            logprobs = True
        if logprobs:
            logprobs = to_openai_style_logprobs(
                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
            )
            token_logprobs = []
            for token_idx, (token, logprob) in enumerate(
                zip(logprobs.tokens, logprobs.token_logprobs)
            ):
                token_bytes = list(token.encode("utf-8"))
                top_logprobs = []
                if logprobs.top_logprobs:
                    for top_token, top_logprob in logprobs.top_logprobs[
                        token_idx
                    ].items():
                        top_token_bytes = list(top_token.encode("utf-8"))
                        top_logprobs.append(
                            TopLogprob(
                                token=top_token,
                                bytes=top_token_bytes,
                                logprob=top_logprob,
                            )
                        )
                token_logprobs.append(
                    ChatCompletionTokenLogprob(
                        token=token,
                        bytes=token_bytes,
                        logprob=logprob,
                        top_logprobs=top_logprobs,
                    )
                )

            choice_logprobs = ChoiceLogprobs(content=token_logprobs)
        else:
            choice_logprobs = None

        finish_reason = ret_item["meta_info"]["finish_reason"]

        tool_calls = None
        text = ret_item["text"]

        if isinstance(request, list):
            tool_choice = request[idx].tool_choice
            tools = request[idx].tools
            separate_reasoning = request[idx].separate_reasoning
        else:
            tool_choice = request.tool_choice
            tools = request.tools
            separate_reasoning = request.separate_reasoning

        if reasoning_parser and separate_reasoning:
            try:
                parser = ReasoningParser(
                    model_type=reasoning_parser, stream_reasoning=False
                )
                reasoning_text, text = parser.parse_non_stream(text)
            except Exception as e:
                logger.error(f"Exception: {e}")
                return create_error_response(
                    "Failed to parse reasoning related info to json format!",
                    status_code=HTTPStatus.BAD_REQUEST,
                )
        else:
            reasoning_text = None

        if tool_choice != "none" and tools:
            parser = FunctionCallParser(tools, tool_call_parser)
            if parser.has_tool_call(text):
                if finish_reason["type"] == "stop":
                    finish_reason["type"] = "tool_calls"
                    finish_reason["matched"] = None
                try:
                    text, call_info_list = parser.parse_non_stream(text)
                    tool_calls = [
                        ToolCall(
                            id=str(call_info.tool_index),
                            function=FunctionResponse(
                                name=call_info.name, arguments=call_info.parameters
                            ),
                        )
                        for call_info in call_info_list
                    ]
                except Exception as e:
                    logger.error(f"Exception: {e}")
                    return create_error_response(
                        "Failed to parse fc related info to json format!",
                        status_code=HTTPStatus.BAD_REQUEST,
                    )

        if to_file:
            # to make the choice data json serializable
            choice_data = {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": text if text else None,
                    "tool_calls": tool_calls,
                    "reasoning_content": reasoning_text if reasoning_text else None,
                },
                "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
                "finish_reason": (finish_reason["type"] if finish_reason else ""),
                "matched_stop": (
                    finish_reason["matched"]
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
            }
        else:
            choice_data = ChatCompletionResponseChoice(
                index=idx,
                message=ChatMessage(
                    role="assistant",
                    content=text if text else None,
                    tool_calls=tool_calls,
                    reasoning_content=reasoning_text if reasoning_text else None,
                ),
                logprobs=choice_logprobs,
                finish_reason=(finish_reason["type"] if finish_reason else ""),
                matched_stop=(
                    finish_reason["matched"]
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
            )

        choices.append(choice_data)

    if to_file:
        responses = []

        for i, choice in enumerate(choices):
            response = {
                "status_code": 200,
                "request_id": ret[i]["meta_info"]["id"],
                "body": {
                    # id stays the same as the request id, but could be changed if needed
                    "id": ret[i]["meta_info"]["id"],
                    "object": "chat.completion",
                    "created": created,
                    "model": request[i].model,
                    "choices": choice,
                    "usage": {
                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
                        + ret[i]["meta_info"]["completion_tokens"],
                    },
                    "system_fingerprint": None,
                },
            }
            responses.append(response)
        return responses
    else:
        prompt_tokens = sum(
            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
        )
        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
        cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
        response = ChatCompletionResponse(
            id=ret[0]["meta_info"]["id"],
            created=created,
            model=request.model,
            choices=choices,
            usage=UsageInfo(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
                prompt_tokens_details=(
                    {"cached_tokens": cached_tokens} if cache_report else None
                ),
            ),
        )
        return response


async def v1_chat_completions(
    tokenizer_manager, raw_request: Request, cache_report=False
):
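    """Handle an OpenAI-compatible chat completions request.

    The JSON body is validated into a ChatCompletionRequest, converted into the
    native SRT request format, and answered either as a server-sent event stream
    or as a single ChatCompletionResponse. An illustrative (not exhaustive)
    request body, using a placeholder model name:

        {"model": "<served-model-name>", "messages": [{"role": "user", "content": "Hi"}]}
    """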
    request_json = await raw_request.json()
    all_requests = [ChatCompletionRequest(**request_json)]
    created = int(time.time())
    adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

    if adapted_request.stream:
        parser_dict = {}
        reasoning_parser_dict = {}

        async def generate_stream_resp():
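            # Per-choice-index bookkeeping (for n > 1): first-chunk flag,
            # accumulated text buffer, and token counters.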
            is_firsts = {}
            stream_buffers = {}
            n_prev_tokens = {}
            prompt_tokens = {}
            completion_tokens = {}
            cached_tokens = {}
            try:
                async for content in tokenizer_manager.generate_request(
                    adapted_request, raw_request
                ):
                    index = content.get("index", 0)
                    text = content["text"]

                    is_first = is_firsts.get(index, True)
                    stream_buffer = stream_buffers.get(index, "")
                    n_prev_token = n_prev_tokens.get(index, 0)

                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                    cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
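
                    # Convert only the logprobs generated since the last chunk
                    # (tracked by n_prev_token) into OpenAI-style objects.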
                    if request.logprobs:
                        logprobs = to_openai_style_logprobs(
                            output_token_logprobs=content["meta_info"][
                                "output_token_logprobs"
                            ][n_prev_token:],
                            output_top_logprobs=content["meta_info"][
                                "output_top_logprobs"
                            ][n_prev_token:],
                        )

                        n_prev_token = len(
                            content["meta_info"]["output_token_logprobs"]
                        )
                        token_logprobs = []
                        for token, logprob in zip(
                            logprobs.tokens, logprobs.token_logprobs
                        ):
                            token_bytes = list(token.encode("utf-8"))
                            top_logprobs = []
                            if logprobs.top_logprobs:
                                for top_token, top_logprob in logprobs.top_logprobs[
                                    0
                                ].items():
                                    top_token_bytes = list(top_token.encode("utf-8"))
                                    top_logprobs.append(
                                        TopLogprob(
                                            token=top_token,
                                            bytes=top_token_bytes,
                                            logprob=top_logprob,
                                        )
                                    )
                            token_logprobs.append(
                                ChatCompletionTokenLogprob(
                                    token=token,
                                    bytes=token_bytes,
                                    logprob=logprob,
                                    top_logprobs=top_logprobs,
                                )
                            )

                        choice_logprobs = ChoiceLogprobs(content=token_logprobs)

                    else:
                        choice_logprobs = None

                    finish_reason = content["meta_info"]["finish_reason"]
                    finish_reason_type = (
                        finish_reason["type"] if finish_reason else None
                    )

                    if is_first:
                        # First chunk with role
                        is_first = False
                        if (
                            tokenizer_manager.server_args.reasoning_parser
                            and request.separate_reasoning
                        ):
                            delta = DeltaMessage(
                                role="assistant", reasoning_content=None
                            )
                        else:
                            delta = DeltaMessage(role="assistant", content=None)
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
                            delta=delta,
                            finish_reason=(
                                None
                                if finish_reason_type and len(finish_reason_type) == 0
                                else finish_reason_type
                            ),
                            matched_stop=(
                                finish_reason["matched"]
                                if finish_reason and "matched" in finish_reason
                                else None
                            ),
                            logprobs=choice_logprobs,
                        )
                        chunk = ChatCompletionStreamResponse(
                            id=content["meta_info"]["id"],
                            created=created,
                            choices=[choice_data],
                            model=request.model,
                        )
                        yield f"data: {chunk.model_dump_json()}\n\n"

                    text = content["text"]
                    delta = text[len(stream_buffer) :]
                    new_stream_buffer = stream_buffer + delta

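                    # With a reasoning parser configured and separate_reasoning
                    # requested, split the delta into reasoning content and
                    # regular content before emitting chunks.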
                    if (
                        tokenizer_manager.server_args.reasoning_parser
                        and request.separate_reasoning
                    ):
                        if index not in reasoning_parser_dict:
                            reasoning_parser_dict[index] = ReasoningParser(
                                tokenizer_manager.server_args.reasoning_parser,
                                request.stream_reasoning,
                            )
                        reasoning_parser = reasoning_parser_dict[index]
                        reasoning_text, delta = reasoning_parser.parse_stream_chunk(
                            delta
                        )
                        if reasoning_text:
                            choice_data = ChatCompletionResponseStreamChoice(
                                index=index,
                                delta=DeltaMessage(
                                    reasoning_content=(
                                        reasoning_text if reasoning_text else None
                                    )
                                ),
                                finish_reason=(
                                    None
                                    if finish_reason_type
                                    and len(finish_reason_type) == 0
                                    else finish_reason_type
                                ),
                            )
                            chunk = ChatCompletionStreamResponse(
                                id=content["meta_info"]["id"],
                                created=created,
                                choices=[choice_data],
                                model=request.model,
                            )
                            yield f"data: {chunk.model_dump_json()}\n\n"
                        if (delta and len(delta) == 0) or not delta:
                            stream_buffers[index] = new_stream_buffer
                            is_firsts[index] = is_first
                            continue

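                    # Tool-call mode: run the delta through the function-call
                    # parser and emit normal text and tool calls as separate chunks.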
                    if request.tool_choice != "none" and request.tools:
                        if index not in parser_dict:
                            parser_dict[index] = FunctionCallParser(
                                tools=request.tools,
                                tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
                            )
                        parser = parser_dict[index]

                        # parse_stream_chunk => returns (normal_text, calls)
                        normal_text, calls = parser.parse_stream_chunk(delta)

                        # 1) if there's normal_text, output it as normal content
                        if normal_text:
                            choice_data = ChatCompletionResponseStreamChoice(
                                index=index,
                                delta=DeltaMessage(
                                    content=normal_text if normal_text else None
                                ),
                                finish_reason=(
                                    None
                                    if finish_reason_type
                                    and len(finish_reason_type) == 0
                                    else finish_reason_type
                                ),
                            )
                            chunk = ChatCompletionStreamResponse(
                                id=content["meta_info"]["id"],
                                created=created,
                                choices=[choice_data],
                                model=request.model,
                            )
                            yield f"data: {chunk.model_dump_json()}\n\n"

                        # 2) if we found calls, we output them as separate chunk(s)
                        for call_item in calls:
                            # transform call_item -> FunctionResponse + ToolCall

                            if (
                                content["meta_info"]["finish_reason"]
                                and content["meta_info"]["finish_reason"]["type"]
                                == "stop"
                            ):
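                                # Generation stopped: send whatever part of the
                                # tool-call arguments has not been streamed yet
                                # as the final parameters payload.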
                                latest_delta_len = 0
                                if isinstance(call_item.parameters, str):
                                    latest_delta_len = len(call_item.parameters)

                                expected_call = json.dumps(
                                    parser.multi_format_parser.detectors[0]
                                    .prev_tool_call_arr[index]
                                    .get("arguments", {}),
                                    ensure_ascii=False,
                                )
                                actual_call = parser.multi_format_parser.detectors[
                                    0
                                ].streamed_args_for_tool[index]
                                if latest_delta_len > 0:
                                    actual_call = actual_call[:-latest_delta_len]
                                remaining_call = expected_call.replace(
                                    actual_call, "", 1
                                )
                                call_item.parameters = remaining_call

                            tool_call = ToolCall(
                                id=str(call_item.tool_index),
                                function=FunctionResponse(
                                    name=call_item.name,
                                    arguments=call_item.parameters,
                                ),
                            )
                            choice_data = ChatCompletionResponseStreamChoice(
                                index=index,
                                delta=DeltaMessage(
                                    role="assistant", tool_calls=[tool_call]
                                ),
                                finish_reason="tool_call",
                            )
                            chunk = ChatCompletionStreamResponse(
                                id=content["meta_info"]["id"],
                                created=created,
                                choices=[choice_data],
                                model=request.model,
                            )
                            yield f"data: {chunk.model_dump_json()}\n\n"

                        stream_buffers[index] = new_stream_buffer
                        is_firsts[index] = is_first

                    else:
                        # No tool calls => just treat this as normal text
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
                            delta=DeltaMessage(content=delta if delta else None),
                            finish_reason=(
                                None
                                if finish_reason_type and len(finish_reason_type) == 0
                                else finish_reason_type
                            ),
                            matched_stop=(
                                finish_reason["matched"]
                                if finish_reason and "matched" in finish_reason
                                else None
                            ),
                            logprobs=choice_logprobs,
                        )
                        chunk = ChatCompletionStreamResponse(
                            id=content["meta_info"]["id"],
                            created=created,
                            choices=[choice_data],
                            model=request.model,
                        )
                        yield f"data: {chunk.model_dump_json()}\n\n"
                        stream_buffers[index] = new_stream_buffer
                        is_firsts[index] = is_first
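
                # Optionally emit a final usage-only chunk (stream_options.include_usage).
                # Prompt tokens are counted once per group of n choices to avoid
                # double counting.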
                if request.stream_options and request.stream_options.include_usage:
                    total_prompt_tokens = sum(
                        tokens
                        for i, tokens in prompt_tokens.items()
                        if i % request.n == 0
                    )
                    total_completion_tokens = sum(
                        tokens for tokens in completion_tokens.values()
                    )
                    cache_report = tokenizer_manager.server_args.enable_cache_report
                    if cache_report:
                        cached_tokens_sum = sum(
                            tokens for tokens in cached_tokens.values()
                        )
                        prompt_tokens_details = {"cached_tokens": cached_tokens_sum}
                    else:
                        prompt_tokens_details = None
                    usage = UsageInfo(
                        prompt_tokens=total_prompt_tokens,
                        completion_tokens=total_completion_tokens,
                        total_tokens=total_prompt_tokens + total_completion_tokens,
                        prompt_tokens_details=prompt_tokens_details,
                    )

                    final_usage_chunk = ChatCompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        created=created,
                        choices=[],
                        model=request.model,
                        usage=usage,
                    )
                    final_usage_data = final_usage_chunk.model_dump_json(
                        exclude_none=True
                    )
                    yield f"data: {final_usage_data}\n\n"
            except ValueError as e:
                error = create_streaming_error_response(str(e))
                yield f"data: {error}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream_resp(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(adapted_request),
        )

    # Non-streaming response.
    try:
        ret = await tokenizer_manager.generate_request(
            adapted_request, raw_request
        ).__anext__()
    except ValueError as e:
        return create_error_response(str(e))
    if not isinstance(ret, list):
        ret = [ret]

    response = v1_chat_generate_response(
        request,
        ret,
        created,
        cache_report=tokenizer_manager.server_args.enable_cache_report,
        tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
        reasoning_parser=tokenizer_manager.server_args.reasoning_parser,
    )

    return response


def v1_embedding_request(all_requests, tokenizer_manager):
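    """Convert OpenAI-style embedding request(s) into a native EmbeddingReqInput.

    Inputs may be strings, token-id lists, or MultimodalEmbeddingInput items;
    all prompts in one call must share the same type.
    """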
    prompts = []
    sampling_params_list = []
    first_prompt_type = type(all_requests[0].input)

    for request in all_requests:
        prompt = request.input
        assert (
            type(prompt) is first_prompt_type
        ), "All prompts must be of the same type in file input settings"
        prompts.append(prompt)

    if len(all_requests) == 1:
        prompt = prompts[0]
        if isinstance(prompt, str) or isinstance(prompt[0], str):
            prompt_kwargs = {"text": prompt}
        elif isinstance(prompt, list) and isinstance(
            prompt[0], MultimodalEmbeddingInput
        ):
            assert (
                chat_template_name is not None
            ), "chat_template_name is required for multimodal inputs"
            texts = []
            images = []
            for item in prompt:
                texts.append(item.text if item.text is not None else None)
                images.append(item.image if item.image is not None else None)
            convs = generate_embedding_convs(texts, images, chat_template_name)
            generate_prompts = []
            for conv in convs:
                generate_prompts.append(conv.get_prompt())
            if len(generate_prompts) == 1:
                prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]}
            else:
                prompt_kwargs = {"text": generate_prompts, "image_data": images}
        else:
            prompt_kwargs = {"input_ids": prompt}
    else:
        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
            prompt_kwargs = {"text": prompts}
        elif isinstance(prompts[0], list) and isinstance(
            prompts[0][0], MultimodalEmbeddingInput
        ):
            # TODO: multiple requests
            raise NotImplementedError(
                "Multiple requests with multimodal inputs are not supported yet"
            )
        else:
            prompt_kwargs = {"input_ids": prompts}

    adapted_request = EmbeddingReqInput(
        **prompt_kwargs,
    )

    if len(all_requests) == 1:
        return adapted_request, all_requests[0]
    return adapted_request, all_requests


def v1_embedding_response(ret, model_path, to_file=False):
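    """Build an OpenAI-style EmbeddingResponse from the raw embedding outputs."""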
    embedding_objects = []
    prompt_tokens = 0
    for idx, ret_item in enumerate(ret):
        embedding_objects.append(
            EmbeddingObject(
                embedding=ret[idx]["embedding"],
                index=idx,
            )
        )
        prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]

    return EmbeddingResponse(
        data=embedding_objects,
        model=model_path,
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            total_tokens=prompt_tokens,
        ),
    )


async def v1_embeddings(tokenizer_manager, raw_request: Request):
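    """Handle an OpenAI-compatible embeddings request.

    An illustrative request body (field names follow the OpenAI embeddings API,
    model name is a placeholder):

        {"model": "<served-model-name>", "input": "hello world"}
    """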
    request_json = await raw_request.json()
    all_requests = [EmbeddingRequest(**request_json)]
    adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)

    try:
        ret = await tokenizer_manager.generate_request(
            adapted_request, raw_request
        ).__anext__()
    except ValueError as e:
        return create_error_response(str(e))

    if not isinstance(ret, list):
        ret = [ret]

    response = v1_embedding_response(ret, tokenizer_manager.model_path)

    return response


def to_openai_style_logprobs(
    input_token_logprobs=None,
    output_token_logprobs=None,
    input_top_logprobs=None,
    output_top_logprobs=None,
):
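    """Convert SRT logprob tuples into an OpenAI-style LogProbs object."""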
    ret_logprobs = LogProbs()

    def append_token_logprobs(token_logprobs):
        for logprob, _, token_text in token_logprobs:
            ret_logprobs.tokens.append(token_text)
            ret_logprobs.token_logprobs.append(logprob)

            # text_offset is not supported yet; use -1 as a placeholder
            ret_logprobs.text_offset.append(-1)

    def append_top_logprobs(top_logprobs):
        for tokens in top_logprobs:
            if tokens is not None:
                ret_logprobs.top_logprobs.append(
                    {token[2]: token[0] for token in tokens}
                )
            else:
                ret_logprobs.top_logprobs.append(None)

    if input_token_logprobs is not None:
        append_token_logprobs(input_token_logprobs)
    if output_token_logprobs is not None:
        append_token_logprobs(output_token_logprobs)
    if input_top_logprobs is not None:
        append_top_logprobs(input_top_logprobs)
    if output_top_logprobs is not None:
        append_top_logprobs(output_top_logprobs)

    return ret_logprobs