predictor.py 3.83 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Copyright (c) Opendatalab. All rights reserved.

import time

from loguru import logger

from .base_predictor import (
    DEFAULT_MAX_NEW_TOKENS,
    DEFAULT_NO_REPEAT_NGRAM_SIZE,
    DEFAULT_PRESENCE_PENALTY,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    BasePredictor,
)
from .sglang_client_predictor import SglangClientPredictor

# Optional backends: each import is attempted once at module load. The
# *_loaded flags record whether it succeeded, so get_predictor can raise a
# clear ImportError only when a caller actually selects that backend.
hf_loaded = False
try:
    from .hf_predictor import HuggingfacePredictor

    hf_loaded = True
except ImportError as e:
    logger.warning("hf is not installed. If you are not using transformers, you can ignore this warning.")
    # The warning above does not say *why* the import failed; keep the real
    # cause available for diagnosis without adding noise at the default level.
    logger.debug(f"Importing the transformers backend failed: {e!r}")

engine_loaded = False
try:
    from sglang.srt.server_args import ServerArgs

    from .sglang_engine_predictor import SglangEnginePredictor

    engine_loaded = True
except Exception as e:
    # Broader than the ImportError guard above -- presumably deliberate, since
    # importing sglang may fail with non-ImportError exceptions (TODO confirm).
    logger.warning("sglang is not installed. If you are not using sglang, you can ignore this warning.")
    # Because the catch is broad, an installed-but-broken sglang would otherwise
    # be indistinguishable from a missing one; record the actual error.
    logger.debug(f"Importing the sglang-engine backend failed: {e!r}")


def get_predictor(
    backend: str = "sglang-client",
    model_path: str | None = None,
    server_url: str | None = None,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    top_k: int = DEFAULT_TOP_K,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
    no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    http_timeout: int = 600,
    **kwargs,
) -> BasePredictor:
    """Build a predictor for the requested inference backend.

    Args:
        backend: One of ``"transformers"``, ``"sglang-engine"`` or
            ``"sglang-client"``.
        model_path: Model location; required by the ``transformers`` and
            ``sglang-engine`` backends.
        server_url: Server address; required by the ``sglang-client`` backend.
        temperature, top_p, top_k, repetition_penalty, presence_penalty,
        no_repeat_ngram_size, max_new_tokens: Generation settings passed to
            the predictor of every backend.
        http_timeout: Passed only to the ``sglang-client`` predictor.
        **kwargs: Forwarded to ``HuggingfacePredictor`` for ``transformers``
            and to ``ServerArgs`` for ``sglang-engine``. The
            ``sglang-client`` backend does not use them and logs a warning.

    Returns:
        The constructed predictor.

    Raises:
        ValueError: If the argument required by the selected backend
            (``model_path`` or ``server_url``) is missing, or if ``backend``
            is not supported.
        ImportError: If the optional dependency for the selected backend
            failed to import when this module was loaded.
    """
    # perf_counter is monotonic, so the reported cost is not skewed by
    # wall-clock adjustments the way time.time() can be.
    start_time = time.perf_counter()

    # Shared by every backend; named parameters, so they can never collide
    # with keys in **kwargs.
    sampling_params = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "presence_penalty": presence_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "max_new_tokens": max_new_tokens,
    }

    if backend == "transformers":
        if not model_path:
            raise ValueError("model_path must be provided for transformers backend.")
        if not hf_loaded:
            raise ImportError(
                "transformers is not installed, so huggingface backend cannot be used. "
                "If you need to use huggingface backend, please install transformers first."
            )
        predictor = HuggingfacePredictor(
            model_path=model_path,
            **sampling_params,
            **kwargs,
        )
    elif backend == "sglang-engine":
        if not model_path:
            raise ValueError("model_path must be provided for sglang-engine backend.")
        if not engine_loaded:
            raise ImportError(
                "sglang is not installed, so sglang-engine backend cannot be used. "
                "If you need to use sglang-engine backend for inference, "
                "please install sglang[all]==0.4.7 or a newer version."
            )
        # Extra keyword arguments configure the embedded engine, not sampling.
        predictor = SglangEnginePredictor(
            server_args=ServerArgs(model_path, **kwargs),
            **sampling_params,
        )
    elif backend == "sglang-client":
        if not server_url:
            raise ValueError("server_url must be provided for sglang-client backend.")
        if kwargs:
            # The client predictor has nowhere to forward these; say so rather
            # than silently dropping arguments the caller expected to matter.
            logger.warning(f"sglang-client backend ignores extra arguments: {sorted(kwargs)}")
        predictor = SglangClientPredictor(
            server_url=server_url,
            http_timeout=http_timeout,
            **sampling_params,
        )
    else:
        raise ValueError(f"Unsupported backend: {backend}. Supports: transformers, sglang-engine, sglang-client.")

    elapsed = round(time.perf_counter() - start_time, 2)
    logger.info(f"get_predictor cost: {elapsed}s")
    return predictor