llm.py 6.13 KB
Newer Older
1
from typing import List, Optional, Union
2
3

from tqdm import tqdm
Zhuohan Li's avatar
Zhuohan Li committed
4
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
8
9
10
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils import Counter
11
12
13


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the `AsyncLLMEngine` class instead.
    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        seed: The seed to initialize the random number generator for sampling.
        **kwargs: Any further `EngineArgs` fields, forwarded unchanged.
            `disable_log_stats` defaults to True here unless given explicitly.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        seed: int = 0,
        **kwargs,
    ) -> None:
        # Engine stats logging is switched off by default for offline use;
        # an explicit caller-provided value (True or False) is respected.
        kwargs.setdefault("disable_log_stats", True)
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            seed=seed,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(engine_args)
        # Issues the ids for every request this object submits to the engine
        # (see `_add_request`); `_run_engine` relies on them for ordering.
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        """Return the tokenizer currently used by the underlying engine."""
        return self.llm_engine.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        """Replace the tokenizer used by the underlying engine."""
        self.llm_engine.tokenizer = tokenizer

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[SamplingParams] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        NOTE: This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: A prompt, or a list of prompts, to generate completions
                for. A single string is treated as a one-element list.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters. The same object
                is shared by every request in this call.
            prompt_token_ids: A list of token IDs for the prompts. If None, we
                use the tokenizer to convert the prompts to token IDs.
            use_tqdm: Whether to use tqdm to display the progress bar.

        Returns:
            A list of `RequestOutput` objects containing the generated
            completions in the same order as the input prompts.

        Raises:
            ValueError: If neither `prompts` nor `prompt_token_ids` is given,
                or if both are given with different lengths.
        """
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        # Add requests to the engine. Either input may be None (but not both,
        # checked above); when both are present their lengths agree.
        if prompts is not None:
            num_requests = len(prompts)
        else:
            num_requests = len(prompt_token_ids)
        for i in range(num_requests):
            prompt = prompts[i] if prompts is not None else None
            token_ids = (prompt_token_ids[i]
                         if prompt_token_ids is not None else None)
            self._add_request(prompt, sampling_params, token_ids)
        return self._run_engine(use_tqdm)

    def _add_request(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
    ) -> None:
        """Submit one request to the engine under a fresh numeric id."""
        # The id is the string form of the next counter value; `_run_engine`
        # parses it back with `int()` to restore submission order.
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id, prompt, sampling_params,
                                    prompt_token_ids)

    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
        """Step the engine until no request is unfinished.

        Args:
            use_tqdm: Whether to show a progress bar counting finished
                requests.

        Returns:
            The finished `RequestOutput`s, ordered by request id (i.e. the
            order in which `_add_request` submitted them).
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests, desc="Processed prompts")
        # Run the engine.
        outputs: List[RequestOutput] = []
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if output.finished:
                        outputs.append(output)
                        if use_tqdm:
                            pbar.update(1)
        finally:
            # Close the progress bar even if an engine step raises.
            if use_tqdm:
                pbar.close()
        # Requests finish in completion order, not submission order: a short
        # prompt added later can finish before a long one added earlier. Sort
        # by the numeric ids from `_add_request` so results line up with the
        # caller's input order, as `generate` promises.
        # NOTE(review): assumes `self.request_counter` yields increasing ints
        # and that every request in the engine came through `_add_request`
        # (non-numeric ids would make `int()` fail) — confirm.
        outputs.sort(key=lambda out: int(out.request_id))
        return outputs