"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
The entry point of the inference server.
SRT = SGLang Runtime.
"""

import asyncio
import dataclasses
import json
import logging
import multiprocessing as mp
import os
import sys
import threading
import time
from http import HTTPStatus
from typing import Dict, List, Optional, Union

# Work around a Python threading quirk: no-op threading._register_atexit so
# lingering background threads cannot block interpreter shutdown.
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.responses import JSONResponse, Response, StreamingResponse

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller_multi import (
    start_controller_process as start_controller_process_multi,
)
from sglang.srt.managers.controller_single import (
    launch_tp_servers,
    start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,
    v1_batches,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
    v1_embeddings,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    allocate_init_ports,
    assert_pkg_version,
    enable_show_time_cost,
    kill_child_process,
    maybe_set_triton_cache_manager,
    prepare_model,
    prepare_tokenizer,
    set_ulimit,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


app = FastAPI()
tokenizer_manager = None


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.get("/get_model_info")
async def get_model_info():
    result = {
        "model_path": tokenizer_manager.model_path,
        "is_generation": tokenizer_manager.is_generation,
    }
    return result


@app.get("/get_server_args")
async def get_server_args():
    return dataclasses.asdict(tokenizer_manager.server_args)


@app.get("/flush_cache")
async def flush_cache():
    tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:

        async def stream_results():
            try:
                async for out in tokenizer_manager.generate_request(obj, request):
                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            except ValueError as e:
                out = {"error": {"message": str(e)}}
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(obj),
        )
    else:
        try:
            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
            return ret
        except ValueError as e:
            return JSONResponse(
                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
            )


app.post("/generate")(generate_request)
app.put("/generate")(generate_request)


async def encode_request(obj: EmbeddingReqInput, request: Request):
    """Handle an embedding request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return JSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/encode")(encode_request)
app.put("/encode")(encode_request)


@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(tokenizer_manager, raw_request)


@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)


@app.post("/v1/embeddings")
async def openai_v1_embeddings(raw_request: Request):
    response = await v1_embeddings(tokenizer_manager, raw_request)
    return response


@app.get("/v1/models")
def available_models():
    """Show available models."""
    served_model_names = [tokenizer_manager.served_model_name]
    model_cards = []
    for served_model_name in served_model_names:
        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
    return ModelList(data=model_cards)


@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    return await v1_files_create(
        file, purpose, tokenizer_manager.server_args.file_storage_pth
    )


@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/delete
    return await v1_delete_file(file_id)


@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    return await v1_batches(tokenizer_manager, raw_request)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)


@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve
    return await v1_retrieve_file(file_id)


@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
    return await v1_retrieve_file_content(file_id)


def launch_server(
    server_args: ServerArgs,
    model_overide_args: Optional[dict] = None,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """Launch an HTTP server."""
    global tokenizer_manager

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    server_args.check_server_args()
    _set_envs_and_config(server_args)

    # Allocate ports
    server_args.port, server_args.additional_ports = allocate_init_ports(
        server_args.port,
        server_args.additional_ports,
        server_args.dp_size,
    )
    ports = server_args.additional_ports
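    # Port layout (as consumed below): the first three additional ports serve
    # the tokenizer, controller, and detokenizer processes; the rest are
    # reserved for NCCL initialization across tensor-parallel workers.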
    port_args = PortArgs(
        tokenizer_port=ports[0],
        controller_port=ports[1],
        detokenizer_port=ports[2],
        nccl_ports=ports[3:],
    )
    logger.info(f"{server_args=}")

    # If the model comes from www.modelscope.cn, download it before loading.
    server_args.model_path = prepare_model(server_args.model_path)
    server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)

    # Launch processes for multi-node tensor parallelism
    if server_args.nnodes > 1:
        if server_args.node_rank != 0:
            tp_size_local = server_args.tp_size // server_args.nnodes
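            # `gpu_ids` repeats the local device ids once per node, so global
            # tensor-parallel rank r runs on local GPU r % tp_size_local;
            # `tp_rank_range` selects the ranks hosted on this worker node.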
            gpu_ids = [
                i for _ in range(server_args.nnodes) for i in range(tp_size_local)
            ]
            tp_rank_range = list(
                range(
                    server_args.node_rank * tp_size_local,
                    (server_args.node_rank + 1) * tp_size_local,
                )
            )
            procs = launch_tp_servers(
                gpu_ids,
                tp_rank_range,
                server_args,
                ports[3],
                model_overide_args,
            )
            # Worker nodes only host tensor-parallel servers; block until those
            # processes exit instead of spinning in a busy loop.
            for proc in procs:
                proc.join()
            return

    # Launch processes
    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

    if server_args.dp_size == 1:
        start_process = start_controller_process_single
    else:
        start_process = start_controller_process_multi
    proc_controller = mp.Process(
        target=start_process,
        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
    )
    proc_controller.start()
    proc_detoken = mp.Process(
        target=start_detokenizer_process,
        args=(
            server_args,
            port_args,
            pipe_detoken_writer,
        ),
    )
    proc_detoken.start()

    # Wait for the model to finish loading
    controller_init_state = pipe_controller_reader.recv()
    detoken_init_state = pipe_detoken_reader.recv()

    if controller_init_state != "init ok" or detoken_init_state != "init ok":
        proc_controller.kill()
        proc_detoken.kill()
        print(
            f"Initialization failed. controller_init_state: {controller_init_state}",
            flush=True,
        )
        print(
            f"Initialization failed. detoken_init_state: {detoken_init_state}",
            flush=True,
        )
        sys.exit(1)
    assert proc_controller.is_alive() and proc_detoken.is_alive()

    # Add api key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Send a warmup request
    t = threading.Thread(
        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
    )
    t.start()

    # Listen for requests
    try:
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()


def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

    # Set ulimit
    set_ulimit()

    # Enable show time cost for debugging
    if server_args.show_time_cost:
        enable_show_time_cost()

    # Disable disk cache
    if server_args.disable_disk_cache:
        disable_cache()

    # Fix triton bugs
    if server_args.tp_size * server_args.dp_size > 1:
        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
        maybe_set_triton_cache_manager()

    # Check flashinfer version
    if not server_args.disable_flashinfer:
        assert_pkg_version(
            "flashinfer",
            "0.1.5",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )


def _wait_and_warmup(server_args, pipe_finish_writer):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()
    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
        sys.exit(1)

    # Only read the response after the success check; `res` is unbound if every
    # attempt raised.
    model_info = res.json()

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    try:
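        # Send one warmup request per data-parallel replica.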
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
        sys.exit(1)

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("init ok")


class Runtime:
    """
    A wrapper for the server.
    This is used for launching the server in a python program without
    using the commond line interface.
    """

    def __init__(
        self,
        log_level: str = "error",
        model_overide_args: Optional[dict] = None,
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports
        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
            self.server_args.port,
            self.server_args.additional_ports,
            self.server_args.dp_size,
        )

        self.url = self.server_args.url()
        self.generate_url = (
            f"http://{self.server_args.host}:{self.server_args.port}/generate"
        )

        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, model_overide_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "init ok":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_child_process(self.pid)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if hasattr(data, "text"):
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate
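
    # Example (a hedged sketch): stream tokens from a running Runtime instance.
    #
    #   async def demo(runtime: "Runtime"):
    #       async for chunk in runtime.async_generate(
    #           "The capital city of France is",
    #           sampling_params={"temperature": 0, "max_new_tokens": 8},
    #       ):
    #           print(chunk, end="", flush=True)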

    def generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
        }
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: str,
    ):
        json_data = {
            "text": prompt,
        }
        response = requests.post(
            self.url + "/encode",
            json=json_data,
        )
        return json.dumps(response.json())

    def __del__(self):
        self.shutdown()