# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conversion between OpenAI APIs and native SRT APIs"""
import asyncio
import json
import logging
import os
import time
import uuid
from http import HTTPStatus
from typing import Dict, List

from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import ORJSONResponse, StreamingResponse
from pydantic import ValidationError
try:
    from outlines.fsm.json_schema import convert_json_schema_to_str
except ImportError:
    # Before outlines 0.0.47, convert_json_schema_to_str is under
    # outlines.integrations.utils
    from outlines.integrations.utils import convert_json_schema_to_str

from sglang.srt.conversation import (
    Conversation,
    SeparatorStyle,
    chat_template_exists,
    generate_chat_conv,
    register_conv_template,
)
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.openai_api.protocol import (
    BatchRequest,
    BatchResponse,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatCompletionTokenLogprob,
    ChatMessage,
    ChoiceLogprobs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    DeltaMessage,
    EmbeddingObject,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
    FileDeleteResponse,
    FileRequest,
    FileResponse,
    LogProbs,
    TopLogprob,
    UsageInfo,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)

chat_template_name = None

class FileMetadata:
    def __init__(self, filename: str, purpose: str):
        self.filename = filename
        self.purpose = purpose


# In-memory storage for batch jobs and files
batch_storage: Dict[str, BatchResponse] = {}
file_id_request: Dict[str, FileMetadata] = {}
file_id_response: Dict[str, FileResponse] = {}
# map file id to file path in SGLang backend
file_id_storage: Dict[str, str] = {}


# backend storage directory
storage_dir = None


def create_error_response(
    message: str,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
):
    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
    return ORJSONResponse(content=error.model_dump(), status_code=error.code)


def create_streaming_error_response(
    message: str,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> str:
    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
    json_str = json.dumps({"error": error.model_dump()})
    return json_str


def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
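    """Load the chat template from a built-in name, a Jinja file, or a JSON conversation spec."""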
    global chat_template_name

    logger.info(
        f"Use chat template for the OpenAI-compatible API server: {chat_template_arg}"
    )
    if not chat_template_exists(chat_template_arg):
        if not os.path.exists(chat_template_arg):
            raise RuntimeError(
                f"Chat template {chat_template_arg} is not a built-in template name "
                "or a valid chat template file path."
            )
        if chat_template_arg.endswith(".jinja"):
            with open(chat_template_arg, "r") as f:
                chat_template = "".join(f.readlines()).strip("\n")
            tokenizer_manager.tokenizer.chat_template = chat_template.replace(
                "\\n", "\n"
            )
            chat_template_name = None
        else:
            assert chat_template_arg.endswith(
                ".json"
            ), "unrecognized format of chat template file"
            with open(chat_template_arg, "r") as filep:
                template = json.load(filep)
                try:
                    sep_style = SeparatorStyle[template["sep_style"]]
                except KeyError:
                    raise ValueError(
                        f"Unknown separator style: {template['sep_style']}"
                    ) from None
                register_conv_template(
                    Conversation(
                        name=template["name"],
                        system_template=template["system"] + "\n{system_message}",
                        system_message=template.get("system_message", ""),
                        roles=(template["user"], template["assistant"]),
                        sep_style=sep_style,
                        sep=template.get("sep", "\n"),
                        stop_str=template["stop_str"],
                    ),
                    override=True,
                )
            chat_template_name = template["name"]
    else:
        chat_template_name = chat_template_arg


async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
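    """Save an uploaded file to the backend storage directory and register it in the in-memory file maps."""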
    try:
        global storage_dir
        if file_storage_pth:
            storage_dir = file_storage_pth
        # Read the file content
        file_content = await file.read()

        # Create an instance of RequestBody
        request_body = FileRequest(file=file_content, purpose=purpose)

        # Save the file to the sglang_oai_storage directory
        os.makedirs(storage_dir, exist_ok=True)
        file_id = f"backend_input_file-{uuid.uuid4()}"
        filename = f"{file_id}.jsonl"
        file_path = os.path.join(storage_dir, filename)

        with open(file_path, "wb") as f:
            f.write(request_body.file)

        # add info to global file map
        file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
        file_id_storage[file_id] = file_path

        # Return the response in the required format
        response = FileResponse(
            id=file_id,
            bytes=len(request_body.file),
            created_at=int(time.time()),
            filename=file.filename,
            purpose=request_body.purpose,
        )
        file_id_response[file_id] = response

        return response
    except ValidationError as e:
        return {"error": "Invalid input", "details": e.errors()}


async def v1_delete_file(file_id: str):
    # Retrieve the file response from the in-memory storage
    file_response = file_id_response.get(file_id)
    if file_response is None:
        raise HTTPException(status_code=404, detail="File not found")
    file_path = file_id_storage.get(file_id)
    if file_path is None:
        raise HTTPException(status_code=404, detail="File not found")
    os.remove(file_path)
    del file_id_response[file_id]
    del file_id_storage[file_id]
    return FileDeleteResponse(id=file_id, deleted=True)


async def v1_batches(tokenizer_manager, raw_request: Request):
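    """Create a batch job from an uploaded input file and start processing it asynchronously."""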
    try:
        body = await raw_request.json()

        batch_request = BatchRequest(**body)

        batch_id = f"batch_{uuid.uuid4()}"

        # Create an instance of BatchResponse
        batch_response = BatchResponse(
            id=batch_id,
            endpoint=batch_request.endpoint,
            input_file_id=batch_request.input_file_id,
            completion_window=batch_request.completion_window,
            created_at=int(time.time()),
            metadata=batch_request.metadata,
        )

        batch_storage[batch_id] = batch_response

        # Start processing the batch asynchronously
        asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))

        # Return the initial batch_response
        return batch_response

    except ValidationError as e:
        return {"error": "Invalid input", "details": e.errors()}
    except Exception as e:
        return {"error": str(e)}


async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
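    """Run every request in a batch input file and write the per-request results to an output file."""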
    try:
        # Update the batch status to "in_progress"
        batch_storage[batch_id].status = "in_progress"
        batch_storage[batch_id].in_progress_at = int(time.time())

        # Retrieve the input file content
        input_file_request = file_id_request.get(batch_request.input_file_id)
        if not input_file_request:
            raise ValueError("Input file not found")

        # Parse the JSONL file and process each request
        input_file_path = file_id_storage.get(batch_request.input_file_id)
        with open(input_file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        total_requests = len(lines)
        completed_requests = 0
        failed_requests = 0

        all_ret = []
        end_point = batch_storage[batch_id].endpoint
        file_request_list = []
        all_requests = []
        request_ids = []
        for line in lines:
            request_data = json.loads(line)
            file_request_list.append(request_data)
            body = request_data["body"]
            request_ids.append(request_data["custom_id"])

            # Although streaming is supported for standalone completions, it is not supported in
            # batch mode (multiple completions in single request).
            if body.get("stream", False):
                raise ValueError("Streaming requests are not supported in batch mode")

            if end_point == "/v1/chat/completions":
                all_requests.append(ChatCompletionRequest(**body))
            elif end_point == "/v1/completions":
                all_requests.append(CompletionRequest(**body))
        if end_point == "/v1/chat/completions":
            adapted_request, request = v1_chat_generate_request(
                all_requests, tokenizer_manager, request_ids=request_ids
            )
        elif end_point == "/v1/completions":
            adapted_request, request = v1_generate_request(
                all_requests, request_ids=request_ids
            )

        try:
            ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
            if not isinstance(ret, list):
                ret = [ret]
            if end_point == "/v1/chat/completions":
                responses = v1_chat_generate_response(
                    request,
                    ret,
                    to_file=True,
                    cache_report=tokenizer_manager.server_args.enable_cache_report,
                )
            else:
                responses = v1_generate_response(
                    request, ret, tokenizer_manager, to_file=True
                )

        except Exception as e:
            logger.error(f"error: {get_exception_traceback()}")
            responses = []
            error_json = {
                "id": f"batch_req_{uuid.uuid4()}",
                "custom_id": request_data.get("custom_id"),
                "response": None,
                "error": {"message": str(e)},
            }
            all_ret.append(error_json)
            failed_requests += len(file_request_list)

        for idx, response in enumerate(responses):
            # The batch_req id here could instead be assigned at batch granularity if needed
            response_json = {
                "id": f"batch_req_{uuid.uuid4()}",
                "custom_id": file_request_list[idx].get("custom_id"),
                "response": response,
                "error": None,
            }
            all_ret.append(response_json)
            completed_requests += 1
        # Write results to a new file
        output_file_id = f"backend_result_file-{uuid.uuid4()}"
        global storage_dir
        output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
        with open(output_file_path, "w", encoding="utf-8") as f:
            for ret in all_ret:
                f.write(json.dumps(ret) + "\n")

        # Update batch response with output file information
        retrieve_batch = batch_storage[batch_id]
        retrieve_batch.output_file_id = output_file_id
        file_id_storage[output_file_id] = output_file_path
        file_id_response[output_file_id] = FileResponse(
            id=output_file_id,
            bytes=os.path.getsize(output_file_path),
            created_at=int(time.time()),
            filename=f"{output_file_id}.jsonl",
            purpose="batch_result",
        )
        # Update batch status to "completed"
        retrieve_batch.status = "completed"
        retrieve_batch.completed_at = int(time.time())
        retrieve_batch.request_counts = {
            "total": total_requests,
            "completed": completed_requests,
            "failed": failed_requests,
        }

    except Exception as e:
        logger.error(f"error: {e}")
        # Update batch status to "failed"
        retrieve_batch = batch_storage[batch_id]
        retrieve_batch.status = "failed"
        retrieve_batch.failed_at = int(time.time())
        retrieve_batch.errors = {"message": str(e)}


async def v1_retrieve_batch(batch_id: str):
    # Retrieve the batch job from the in-memory storage
    batch_response = batch_storage.get(batch_id)
    if batch_response is None:
        raise HTTPException(status_code=404, detail="Batch not found")

    return batch_response


async def v1_cancel_batch(tokenizer_manager, batch_id: str):
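    """Cancel a batch job that is still validating or in progress; the abort runs asynchronously."""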
    # Retrieve the batch job from the in-memory storage
    batch_response = batch_storage.get(batch_id)
    if batch_response is None:
        raise HTTPException(status_code=404, detail="Batch not found")

    # Only cancel when the status is "validating" or "in_progress"
    if batch_response.status in ["validating", "in_progress"]:
        # Start cancelling the batch asynchronously
        asyncio.create_task(
            cancel_batch(
                tokenizer_manager=tokenizer_manager,
                batch_id=batch_id,
                input_file_id=batch_response.input_file_id,
            )
        )

        # Update batch status to "cancelling"
        batch_response.status = "cancelling"

        return batch_response
    else:
        raise HTTPException(
            status_code=500,
            detail=f"Current status is {batch_response.status}, no need to cancel",
        )


async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
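    """Abort all requests belonging to a batch job and mark the batch as cancelled."""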
    try:
        # Update the batch status to "cancelling"
        batch_storage[batch_id].status = "cancelling"

        # Retrieve the input file content
        input_file_request = file_id_request.get(input_file_id)
        if not input_file_request:
            raise ValueError("Input file not found")

        # Parse the JSONL file and process each request
        input_file_path = file_id_storage.get(input_file_id)
        with open(input_file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        file_request_list = []
        request_ids = []
        for line in lines:
            request_data = json.loads(line)
            file_request_list.append(request_data)
            request_ids.append(request_data["custom_id"])

        # Cancel requests by request_ids
        for rid in request_ids:
            tokenizer_manager.abort_request(rid=rid)

        retrieve_batch = batch_storage[batch_id]
        retrieve_batch.status = "cancelled"

    except Exception as e:
        logger.error("error in SGLang:", e)
        # Update batch status to "failed"
        retrieve_batch = batch_storage[batch_id]
        retrieve_batch.status = "failed"
        retrieve_batch.failed_at = int(time.time())
        retrieve_batch.errors = {"message": str(e)}


async def v1_retrieve_file(file_id: str):
    # Retrieve the file response from the in-memory storage
    file_response = file_id_response.get(file_id)
    if file_response is None:
        raise HTTPException(status_code=404, detail="File not found")
    return file_response


async def v1_retrieve_file_content(file_id: str):
    file_pth = file_id_storage.get(file_id)
    if not file_pth or not os.path.exists(file_pth):
        raise HTTPException(status_code=404, detail="File not found")

    def iter_file():
        with open(file_pth, mode="rb") as file_like:
            yield from file_like

    return StreamingResponse(iter_file(), media_type="application/octet-stream")


def v1_generate_request(
    all_requests: List[CompletionRequest], request_ids: List[str] = None
):
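    """Convert one or more OpenAI completion requests into a single GenerateReqInput for the SRT backend."""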
    if len(all_requests) > 1:
        first_prompt_type = type(all_requests[0].prompt)
        for request in all_requests:
            assert (
                type(request.prompt) is first_prompt_type
            ), "All prompts must be of the same type in file input settings"
            if request.n > 1:
                raise ValueError(
                    "Parallel sampling is not supported for completions from files"
                )

    prompts = []
    sampling_params_list = []
    return_logprobs = []
    logprob_start_lens = []
    top_logprobs_nums = []
    lora_paths = []
    for request in all_requests:
        # NOTE: with the OpenAI API, the prompt's logprobs are never computed
        if request.echo and request.logprobs:
            logger.warning(
                "Echo is not compatible with logprobs. "
                "To compute logprobs of input prompt, please use the native /generate API."
            )

        prompts.append(request.prompt)
        lora_paths.append(request.lora_path)
        if request.echo and request.logprobs:
            current_logprob_start_len = 0
        else:
            current_logprob_start_len = -1
        sampling_params_list.append(
            {
                "temperature": request.temperature,
                "max_new_tokens": request.max_tokens,
                "min_new_tokens": request.min_tokens,
                "stop": request.stop,
                "stop_token_ids": request.stop_token_ids,
                "top_p": request.top_p,
                "top_k": request.top_k,
                "min_p": request.min_p,
                "presence_penalty": request.presence_penalty,
                "frequency_penalty": request.frequency_penalty,
                "repetition_penalty": request.repetition_penalty,
                "regex": request.regex,
                "json_schema": request.json_schema,
                "ebnf": request.ebnf,
                "n": request.n,
                "no_stop_trim": request.no_stop_trim,
                "ignore_eos": request.ignore_eos,
                "skip_special_tokens": request.skip_special_tokens,
            }
        )
        return_logprobs.append(request.logprobs is not None)
        logprob_start_lens.append(current_logprob_start_len)
        top_logprobs_nums.append(
            request.logprobs if request.logprobs is not None else 0
        )

    if len(all_requests) == 1:
        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
            prompt_kwargs = {"text": prompts[0]}
        else:
            prompt_kwargs = {"input_ids": prompts[0]}
        sampling_params_list = sampling_params_list[0]
        return_logprobs = return_logprobs[0]
        logprob_start_lens = logprob_start_lens[0]
        top_logprobs_nums = top_logprobs_nums[0]
        lora_paths = lora_paths[0]
    else:
        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
            prompt_kwargs = {"text": prompts}
        else:
            prompt_kwargs = {"input_ids": prompts}
    adapted_request = GenerateReqInput(
        **prompt_kwargs,
        sampling_params=sampling_params_list,
        return_logprob=return_logprobs,
        top_logprobs_num=top_logprobs_nums,
        logprob_start_len=logprob_start_lens,
        return_text_in_logprobs=True,
        stream=all_requests[0].stream,
        rid=request_ids,
        lora_path=lora_paths,
    )
    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
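    """Convert SRT outputs into an OpenAI completion response (or JSON-serializable dicts when to_file=True)."""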
    choices = []
    echo = False

    if (not isinstance(request, list)) and request.echo:
        # TODO: handle the case where the prompt is token ids
        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
            # for the case of multiple str prompts
            prompts = request.prompt
        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
            # for the case of multiple token ids prompts
            prompts = [
                tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
                for prompt in request.prompt
            ]
        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
            # for the case of single token ids prompt
            prompts = [
                tokenizer_manager.tokenizer.decode(
                    request.prompt, skip_special_tokens=True
                )
            ]
        else:
            # for the case of single str prompt
            prompts = [request.prompt]
        echo = True

    for idx, ret_item in enumerate(ret):
        text = ret_item["text"]
        if isinstance(request, list) and request[idx].echo:
            echo = True
            text = request[idx].prompt + text
        if echo and not isinstance(request, list):
            prompt_index = idx // request.n
            text = prompts[prompt_index] + text

        logprobs = False
        if isinstance(request, list) and request[idx].logprobs is not None:
            logprobs = True
        elif (not isinstance(request, list)) and request.logprobs is not None:
            logprobs = True
        if logprobs:
            if echo:
                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
            else:
                input_token_logprobs = None
                input_top_logprobs = None

            logprobs = to_openai_style_logprobs(
                input_token_logprobs=input_token_logprobs,
                input_top_logprobs=input_top_logprobs,
                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
            )
        else:
            logprobs = None

        finish_reason = ret_item["meta_info"]["finish_reason"]

        if to_file:
            # to make the choice data json serializable
            choice_data = {
                "index": 0,
                "text": text,
                "logprobs": logprobs,
                "finish_reason": (finish_reason["type"] if finish_reason else ""),
                "matched_stop": (
                    finish_reason["matched"]
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
            }
        else:
            choice_data = CompletionResponseChoice(
                index=idx,
                text=text,
                logprobs=logprobs,
                finish_reason=(finish_reason["type"] if finish_reason else ""),
                matched_stop=(
                    finish_reason["matched"]
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
            )

        choices.append(choice_data)

    if to_file:
        responses = []
        for i, choice in enumerate(choices):
            response = {
                "status_code": 200,
                "request_id": ret[i]["meta_info"]["id"],
                "body": {
                    # Kept identical to the standard completion response body; can be changed if needed
                    "id": ret[i]["meta_info"]["id"],
                    "object": "text_completion",
                    "created": int(time.time()),
                    "model": request[i].model,
                    "choices": choice,
                    "usage": {
                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
                        + ret[i]["meta_info"]["completion_tokens"],
                    },
                    "system_fingerprint": None,
                },
            }
            responses.append(response)
        return responses
    else:
        prompt_tokens = sum(
            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
        )
        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
        response = CompletionResponse(
            id=ret[0]["meta_info"]["id"],
            model=request.model,
            choices=choices,
            usage=UsageInfo(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
        )
    return response


async def v1_completions(tokenizer_manager, raw_request: Request):
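    """Handle the OpenAI-compatible /v1/completions endpoint for streaming and non-streaming requests."""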
    request_json = await raw_request.json()
    if "extra_body" in request_json:
        extra = request_json["extra_body"]
        if "ebnf" in extra:
            request_json["ebnf"] = extra["ebnf"]
        if "regex" in extra:
            request_json["regex"] = extra["regex"]
        # remove extra_body to avoid pydantic conflict
        del request_json["extra_body"]
    all_requests = [CompletionRequest(**request_json)]
    adapted_request, request = v1_generate_request(all_requests)

    if adapted_request.stream:

        async def generate_stream_resp():
            stream_buffers = {}
            n_prev_tokens = {}
            prompt_tokens = {}
            completion_tokens = {}
            try:
                async for content in tokenizer_manager.generate_request(
                    adapted_request, raw_request
                ):
                    index = content.get("index", 0)

                    stream_buffer = stream_buffers.get(index, "")
                    n_prev_token = n_prev_tokens.get(index, 0)

                    text = content["text"]
                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                    completion_tokens[index] = content["meta_info"]["completion_tokens"]

                    if not stream_buffer:  # The first chunk
                        if request.echo:
                            if isinstance(request.prompt, str):
                                # for the case of single str prompts
                                prompts = request.prompt
                            elif isinstance(request.prompt, list):
                                if isinstance(request.prompt[0], str):
                                    # for the case of multiple str prompts
                                    prompts = request.prompt[index // request.n]
                                elif isinstance(request.prompt[0], int):
                                    # for the case of single token ids prompt
                                    prompts = tokenizer_manager.tokenizer.decode(
                                        request.prompt, skip_special_tokens=True
                                    )
                                elif isinstance(request.prompt[0], list) and isinstance(
                                    request.prompt[0][0], int
                                ):
                                    # for the case of multiple token ids prompts
                                    prompts = tokenizer_manager.tokenizer.decode(
                                        request.prompt[index // request.n],
                                        skip_special_tokens=True,
                                    )
                            # Prepend prompt in response text.
                            text = prompts + text
                    if request.logprobs is not None:
                        # This is the first chunk and echo is enabled.
                        if not stream_buffer and request.echo:
                            input_token_logprobs = content["meta_info"][
                                "input_token_logprobs"
                            ]
                            input_top_logprobs = content["meta_info"][
                                "input_top_logprobs"
                            ]
                        else:
                            input_token_logprobs = None
                            input_top_logprobs = None

                        logprobs = to_openai_style_logprobs(
                            input_token_logprobs=input_token_logprobs,
                            input_top_logprobs=input_top_logprobs,
                            output_token_logprobs=content["meta_info"][
                                "output_token_logprobs"
                            ][n_prev_token:],
                            output_top_logprobs=content["meta_info"][
                                "output_top_logprobs"
                            ][n_prev_token:],
                        )
                        n_prev_token = len(
                            content["meta_info"]["output_token_logprobs"]
                        )
                    else:
                        logprobs = None
                    delta = text[len(stream_buffer) :]
                    stream_buffer = stream_buffer + delta
                    finish_reason = content["meta_info"]["finish_reason"]
                    choice_data = CompletionResponseStreamChoice(
                        index=index,
                        text=delta,
                        logprobs=logprobs,
                        finish_reason=(finish_reason["type"] if finish_reason else ""),
                        matched_stop=(
                            finish_reason["matched"]
                            if finish_reason and "matched" in finish_reason
                            else None
                        ),
                    )
                    chunk = CompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        object="text_completion",
                        choices=[choice_data],
                        model=request.model,
                    )

                    stream_buffers[index] = stream_buffer
                    n_prev_tokens[index] = n_prev_token

                    yield f"data: {chunk.model_dump_json()}\n\n"
                if request.stream_options and request.stream_options.include_usage:
                    total_prompt_tokens = sum(
                        tokens
                        for i, tokens in prompt_tokens.items()
                        if i % request.n == 0
                    )
                    total_completion_tokens = sum(
                        tokens for tokens in completion_tokens.values()
                    )
                    usage = UsageInfo(
                        prompt_tokens=total_prompt_tokens,
                        completion_tokens=total_completion_tokens,
                        total_tokens=total_prompt_tokens + total_completion_tokens,
                    )

                    final_usage_chunk = CompletionStreamResponse(
                        id=str(uuid.uuid4().hex),
                        choices=[],
                        model=request.model,
                        usage=usage,
                    )
                    final_usage_data = final_usage_chunk.model_dump_json(
                        exclude_unset=True, exclude_none=True
                    )
                    yield f"data: {final_usage_data}\n\n"
            except ValueError as e:
                error = create_streaming_error_response(str(e))
                yield f"data: {error}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream_resp(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(adapted_request),
        )

    # Non-streaming response.
    try:
        ret = await tokenizer_manager.generate_request(
            adapted_request, raw_request
        ).__anext__()
    except ValueError as e:
        return create_error_response(str(e))
    if not isinstance(ret, list):
        ret = [ret]

    response = v1_generate_response(request, ret, tokenizer_manager)
    return response


def v1_chat_generate_request(
    all_requests: List[ChatCompletionRequest],
    tokenizer_manager,
    request_ids: List[str] = None,
):
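    """Convert one or more OpenAI chat completion requests into a single GenerateReqInput, applying the chat template."""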
    input_ids = []
    sampling_params_list = []
    image_data_list = []
    return_logprobs = []
    logprob_start_lens = []
    top_logprobs_nums = []
    modalities_list = []
    lora_paths = []

    # NOTE: with the OpenAI API, the prompt's logprobs are never computed

    for request in all_requests:
        # Prep the data needed for the underlying GenerateReqInput:
        #  - prompt: The full prompt string.
        #  - stop: Custom stop tokens.
        #  - image_data: None or a list of image strings (URLs or base64 strings).
        #    None skips any image processing in GenerateReqInput.
        if not isinstance(request.messages, str):
            # Apply chat template and its stop strings.
            if chat_template_name is None:
                openai_compatible_messages = []
                for message in request.messages:
                    if isinstance(message.content, str):
                        openai_compatible_messages.append(
                            {"role": message.role, "content": message.content}
                        )
                    else:
                        content_list = message.dict()["content"]
                        for content in content_list:
                            if content["type"] == "text":
                                openai_compatible_messages.append(
                                    {"role": message.role, "content": content["text"]}
                                )
                if openai_compatible_messages[-1]["role"] == "assistant":
                    assistant_prefix = openai_compatible_messages[-1]["content"]
                    openai_compatible_messages = openai_compatible_messages[:-1]
                else:
                    assistant_prefix = None
                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
                    openai_compatible_messages,
                    tokenize=True,
                    add_generation_prompt=True,
                )
                if assistant_prefix:
                    prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
                stop = request.stop
                image_data = None
                modalities = []
            else:
                conv = generate_chat_conv(request, chat_template_name)
                prompt = conv.get_prompt()
                image_data = conv.image_data
                modalities = conv.modalities
                stop = conv.stop_str or []
                if request.stop:
                    if isinstance(request.stop, str):
                        stop.append(request.stop)
                    else:
                        stop.extend(request.stop)
                prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
        else:
            # Use the raw prompt and stop strings if the messages field is already a string.
            prompt_ids = request.messages
            stop = request.stop
            image_data = None
            modalities = []
        input_ids.append(prompt_ids)
        return_logprobs.append(request.logprobs)
        logprob_start_lens.append(-1)
        top_logprobs_nums.append(request.top_logprobs or 0)
        lora_paths.append(request.lora_path)

        sampling_params = {
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "min_new_tokens": request.min_tokens,
            "stop": stop,
            "stop_token_ids": request.stop_token_ids,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "min_p": request.min_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "repetition_penalty": request.repetition_penalty,
            "regex": request.regex,
            "ebnf": request.ebnf,
            "n": request.n,
            "no_stop_trim": request.no_stop_trim,
            "ignore_eos": request.ignore_eos,
            "skip_special_tokens": request.skip_special_tokens,
        }
        if request.response_format and request.response_format.type == "json_schema":
            sampling_params["json_schema"] = convert_json_schema_to_str(
                request.response_format.json_schema.schema_
            )
        sampling_params_list.append(sampling_params)

        image_data_list.append(image_data)
        modalities_list.append(modalities)
    if len(all_requests) == 1:
        if isinstance(input_ids[0], str):
            prompt_kwargs = {"text": input_ids[0]}
        else:
            prompt_kwargs = {"input_ids": input_ids[0]}
        sampling_params_list = sampling_params_list[0]
        image_data_list = image_data_list[0]
        return_logprobs = return_logprobs[0]
        logprob_start_lens = logprob_start_lens[0]
        top_logprobs_nums = top_logprobs_nums[0]
        modalities_list = modalities_list[0]
        lora_paths = lora_paths[0]
    else:
        if isinstance(input_ids[0], str):
            prompt_kwargs = {"text": input_ids}
        else:
            prompt_kwargs = {"input_ids": input_ids}
    adapted_request = GenerateReqInput(
        **prompt_kwargs,
        image_data=image_data_list,
        sampling_params=sampling_params_list,
        return_logprob=return_logprobs,
        logprob_start_len=logprob_start_lens,
        top_logprobs_num=top_logprobs_nums,
        stream=all_requests[0].stream,
        return_text_in_logprobs=True,
        rid=request_ids,
        modalities=modalities_list,
        lora_path=lora_paths,
    )

    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
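    """Convert SRT outputs into an OpenAI chat completion response (or JSON-serializable dicts when to_file=True)."""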
    choices = []

    for idx, ret_item in enumerate(ret):
        logprobs = False
        if isinstance(request, list) and request[idx].logprobs:
            logprobs = True
        elif (not isinstance(request, list)) and request.logprobs:
            logprobs = True
        if logprobs:
            logprobs = to_openai_style_logprobs(
                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
            )
            token_logprobs = []
            for token_idx, (token, logprob) in enumerate(
                zip(logprobs.tokens, logprobs.token_logprobs)
            ):
                token_bytes = list(token.encode("utf-8"))
                top_logprobs = []
                if logprobs.top_logprobs:
                    for top_token, top_logprob in logprobs.top_logprobs[
                        token_idx
                    ].items():
                        top_token_bytes = list(top_token.encode("utf-8"))
                        top_logprobs.append(
                            TopLogprob(
                                token=top_token,
                                bytes=top_token_bytes,
                                logprob=top_logprob,
                            )
                        )
                token_logprobs.append(
                    ChatCompletionTokenLogprob(
                        token=token,
                        bytes=token_bytes,
                        logprob=logprob,
                        top_logprobs=top_logprobs,
                    )
                )

            choice_logprobs = ChoiceLogprobs(content=token_logprobs)
        else:
            choice_logprobs = None
        finish_reason = ret_item["meta_info"]["finish_reason"]

        if to_file:
            # to make the choice data json serializable
            choice_data = {
                "index": 0,
                "message": {"role": "assistant", "content": ret_item["text"]},
                "logprobs": choice_logprobs,
                "finish_reason": (finish_reason["type"] if finish_reason else ""),
                "matched_stop": (
                    finish_reason["matched"]
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
            }
        else:
            choice_data = ChatCompletionResponseChoice(
                index=idx,
                message=ChatMessage(role="assistant", content=ret_item["text"]),
                logprobs=choice_logprobs,
                finish_reason=(finish_reason["type"] if finish_reason else ""),
                matched_stop=(
                    finish_reason["matched"]
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
            )

        choices.append(choice_data)
    if to_file:
        responses = []

        for i, choice in enumerate(choices):
            response = {
                "status_code": 200,
                "request_id": ret[i]["meta_info"]["id"],
                "body": {
                    # Kept identical to the standard chat completion response body; can be changed if needed
                    "id": ret[i]["meta_info"]["id"],
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": request[i].model,
                    "choices": choice,
                    "usage": {
                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
                        + ret[i]["meta_info"]["completion_tokens"],
                    },
                    "system_fingerprint": None,
                },
            }
            responses.append(response)
        return responses
    else:
        prompt_tokens = sum(
            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
        )
        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
        cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
        response = ChatCompletionResponse(
            id=ret[0]["meta_info"]["id"],
            model=request.model,
            choices=choices,
            usage=UsageInfo(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
                prompt_tokens_details=(
                    {"cached_tokens": cached_tokens} if cache_report else None
                ),
            ),
        )
        return response


async def v1_chat_completions(tokenizer_manager, raw_request: Request):
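    """Handle the OpenAI-compatible /v1/chat/completions endpoint for streaming and non-streaming requests."""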
    request_json = await raw_request.json()
    if "extra_body" in request_json:
        extra = request_json["extra_body"]
        # For example, if 'ebnf' is given:
        if "ebnf" in extra:
            request_json["ebnf"] = extra["ebnf"]
        if "regex" in extra:
            request_json["regex"] = extra["regex"]
        # remove extra_body to avoid pydantic conflict
        del request_json["extra_body"]
    all_requests = [ChatCompletionRequest(**request_json)]
    adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

    if adapted_request.stream:

        async def generate_stream_resp():
            is_firsts = {}
            stream_buffers = {}
            n_prev_tokens = {}
            prompt_tokens = {}
            completion_tokens = {}
            try:
                async for content in tokenizer_manager.generate_request(
                    adapted_request, raw_request
                ):
                    index = content.get("index", 0)

                    is_first = is_firsts.get(index, True)
                    stream_buffer = stream_buffers.get(index, "")
                    n_prev_token = n_prev_tokens.get(index, 0)

                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                    if request.logprobs:
                        logprobs = to_openai_style_logprobs(
                            output_token_logprobs=content["meta_info"][
                                "output_token_logprobs"
                            ][n_prev_token:],
                            output_top_logprobs=content["meta_info"][
                                "output_top_logprobs"
                            ][n_prev_token:],
                        )

                        n_prev_token = len(
                            content["meta_info"]["output_token_logprobs"]
                        )
                        token_logprobs = []
                        for token, logprob in zip(
                            logprobs.tokens, logprobs.token_logprobs
                        ):
                            token_bytes = list(token.encode("utf-8"))
                            top_logprobs = []
                            if logprobs.top_logprobs:
                                for top_token, top_logprob in logprobs.top_logprobs[
                                    0
                                ].items():
                                    top_token_bytes = list(top_token.encode("utf-8"))
                                    top_logprobs.append(
                                        TopLogprob(
                                            token=top_token,
                                            bytes=top_token_bytes,
                                            logprob=top_logprob,
                                        )
                                    )
                            token_logprobs.append(
                                ChatCompletionTokenLogprob(
                                    token=token,
                                    bytes=token_bytes,
                                    logprob=logprob,
                                    top_logprobs=top_logprobs,
                                )
                            )

                        choice_logprobs = ChoiceLogprobs(content=token_logprobs)

                    else:
                        choice_logprobs = None

                    finish_reason = content["meta_info"]["finish_reason"]

                    if is_first:
                        # First chunk with role
                        is_first = False
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
                            delta=DeltaMessage(role="assistant", content=""),
                            finish_reason=(
                                finish_reason["type"] if finish_reason else ""
                            ),
                            matched_stop=(
                                finish_reason["matched"]
                                if finish_reason and "matched" in finish_reason
                                else None
                            ),
                            logprobs=choice_logprobs,
                        )
                        chunk = ChatCompletionStreamResponse(
                            id=content["meta_info"]["id"],
                            choices=[choice_data],
                            model=request.model,
                        )
                        yield f"data: {chunk.model_dump_json()}\n\n"

                    text = content["text"]
                    delta = text[len(stream_buffer) :]
                    stream_buffer = stream_buffer + delta
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=index,
                        delta=DeltaMessage(content=delta),
                        finish_reason=(finish_reason["type"] if finish_reason else ""),
                        matched_stop=(
                            finish_reason["matched"]
                            if finish_reason and "matched" in finish_reason
                            else None
                        ),
                        logprobs=choice_logprobs,
                    )
                    chunk = ChatCompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        choices=[choice_data],
                        model=request.model,
                    )

                    is_firsts[index] = is_first
                    stream_buffers[index] = stream_buffer
                    n_prev_tokens[index] = n_prev_token

                    yield f"data: {chunk.model_dump_json()}\n\n"
                if request.stream_options and request.stream_options.include_usage:
                    total_prompt_tokens = sum(
                        tokens
                        for i, tokens in prompt_tokens.items()
                        if i % request.n == 0
                    )
                    total_completion_tokens = sum(
                        tokens for tokens in completion_tokens.values()
                    )
                    usage = UsageInfo(
                        prompt_tokens=total_prompt_tokens,
                        completion_tokens=total_completion_tokens,
                        total_tokens=total_prompt_tokens + total_completion_tokens,
                    )

                    final_usage_chunk = ChatCompletionStreamResponse(
                        id=str(uuid.uuid4().hex),
                        choices=[],
                        model=request.model,
                        usage=usage,
                    )
                    final_usage_data = final_usage_chunk.model_dump_json(
                        exclude_unset=True, exclude_none=True
                    )
                    yield f"data: {final_usage_data}\n\n"
            except ValueError as e:
                error = create_streaming_error_response(str(e))
                yield f"data: {error}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream_resp(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(adapted_request),
        )

    # Non-streaming response.
    try:
        ret = await tokenizer_manager.generate_request(
            adapted_request, raw_request
        ).__anext__()
    except ValueError as e:
        return create_error_response(str(e))
    if not isinstance(ret, list):
        ret = [ret]

    response = v1_chat_generate_response(
        request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
    )
    return response


def v1_embedding_request(all_requests, tokenizer_manager):
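    """Convert one or more OpenAI embedding requests into a single EmbeddingReqInput for the SRT backend."""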
    prompts = []
    sampling_params_list = []
    first_prompt_type = type(all_requests[0].input)

    for request in all_requests:
        prompt = request.input
        assert (
            type(prompt) is first_prompt_type
        ), "All prompts must be of the same type in file input settings"
        prompts.append(prompt)

    if len(all_requests) == 1:
        prompt = prompts[0]
        if isinstance(prompt, str) or isinstance(prompt[0], str):
            prompt_kwargs = {"text": prompt}
        else:
            prompt_kwargs = {"input_ids": prompt}
    else:
        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
            prompt_kwargs = {"text": prompts}
        else:
            prompt_kwargs = {"input_ids": prompts}

    adapted_request = EmbeddingReqInput(
        **prompt_kwargs,
    )

    if len(all_requests) == 1:
        return adapted_request, all_requests[0]
    return adapted_request, all_requests


def v1_embedding_response(ret, model_path, to_file=False):
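    """Convert SRT embedding outputs into an OpenAI embedding response with token usage."""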
    embedding_objects = []
    prompt_tokens = 0
    for idx, ret_item in enumerate(ret):
        embedding_objects.append(
            EmbeddingObject(
                embedding=ret[idx]["embedding"],
                index=idx,
            )
        )
        prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]

    return EmbeddingResponse(
        data=embedding_objects,
        model=model_path,
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            total_tokens=prompt_tokens,
        ),
    )


async def v1_embeddings(tokenizer_manager, raw_request: Request):
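    """Handle the OpenAI-compatible /v1/embeddings endpoint."""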
    request_json = await raw_request.json()
    all_requests = [EmbeddingRequest(**request_json)]
    adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)

    try:
        ret = await tokenizer_manager.generate_request(
            adapted_request, raw_request
        ).__anext__()
    except ValueError as e:
        return create_error_response(str(e))

    if not isinstance(ret, list):
        ret = [ret]

    response = v1_embedding_response(ret, tokenizer_manager.model_path)

    return response


def to_openai_style_logprobs(
    input_token_logprobs=None,
    output_token_logprobs=None,
    input_top_logprobs=None,
    output_top_logprobs=None,
):
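    """Convert SRT (logprob, token_id, token_text) tuples into an OpenAI-style LogProbs object."""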
    ret_logprobs = LogProbs()

    def append_token_logprobs(token_logprobs):
        for logprob, _, token_text in token_logprobs:
            ret_logprobs.tokens.append(token_text)
            ret_logprobs.token_logprobs.append(logprob)

            # Text offset is not supported yet
            ret_logprobs.text_offset.append(-1)

    def append_top_logprobs(top_logprobs):
        for tokens in top_logprobs:
            if tokens is not None:
                ret_logprobs.top_logprobs.append(
                    {token[2]: token[0] for token in tokens}
                )
            else:
                ret_logprobs.top_logprobs.append(None)

    if input_token_logprobs is not None:
        append_token_logprobs(input_token_logprobs)
    if output_token_logprobs is not None:
        append_token_logprobs(output_token_logprobs)
    if input_top_logprobs is not None:
        append_top_logprobs(input_top_logprobs)
    if output_top_logprobs is not None:
        append_top_logprobs(output_top_logprobs)
    return ret_logprobs