from typing import List, Optional, Union

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils import Counter


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the `AsyncLLMEngine` class instead.

    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        seed: The seed to initialize the random number generator for sampling.
        **kwargs: Additional keyword arguments forwarded verbatim to
            `EngineArgs`. If `disable_log_stats` is not given, it defaults to
            True here.
    """

    def __init__(
        self,
        model: str,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        seed: int = 0,
        **kwargs,
    ) -> None:
        # Offline batch inference does not need periodic stats logging, so
        # silence it unless the caller explicitly asks otherwise.
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        engine_args = EngineArgs(
            model=model,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            seed=seed,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(engine_args)
        # Source of request IDs; IDs are the decimal strings of its values,
        # which `_run_engine` relies on to restore submission order.
        self.request_counter = Counter()

    def get_tokenizer(
        self,
    ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        """Return the tokenizer currently held by the underlying engine."""
        return self.llm_engine.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        """Replace the tokenizer held by the underlying engine."""
        self.llm_engine.tokenizer = tokenizer

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[SamplingParams] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        NOTE: This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: A prompt or a list of prompts to generate completions
                for. A single string is treated as a one-element list.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
            prompt_token_ids: A list of token IDs for the prompts. If None, we
                use the tokenizer to convert the prompts to token IDs.
            use_tqdm: Whether to use tqdm to display the progress bar.

        Returns:
            A list of `RequestOutput` objects containing the generated
            completions in the same order as the input prompts.

        Raises:
            ValueError: If neither `prompts` nor `prompt_token_ids` is given,
                or if both are given with different lengths.
        """
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        if prompts is not None and prompt_token_ids is not None:
            if len(prompts) != len(prompt_token_ids):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        # Add requests to the engine. Either input may be None, but not both.
        if prompts is not None:
            num_requests = len(prompts)
        else:
            num_requests = len(prompt_token_ids)
        for i in range(num_requests):
            prompt = None if prompts is None else prompts[i]
            token_ids = (None
                         if prompt_token_ids is None else prompt_token_ids[i])
            self._add_request(prompt, sampling_params, token_ids)
        return self._run_engine(use_tqdm)

    def _add_request(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
    ) -> None:
        """Submit one request to the engine under a fresh numeric ID."""
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id, prompt, sampling_params,
                                    prompt_token_ids)

    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
        """Step the engine until idle and return all finished outputs.

        Outputs are returned sorted by request ID, i.e. in submission order,
        regardless of the order in which the engine finished them.
        """
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests, desc="Processed prompts")
        outputs: List[RequestOutput] = []
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if output.finished:
                        outputs.append(output)
                        if pbar is not None:
                            pbar.update(1)
        finally:
            # Close the progress bar even if a step raises, so the terminal
            # is not left with a dangling bar.
            if pbar is not None:
                pbar.close()
        # Requests can finish out of submission order (the engine batches
        # them), so restore the order promised by `generate`. Request IDs are
        # the decimal strings produced in `_add_request`; compare them
        # numerically, because lexical order would put "10" before "2".
        outputs.sort(key=lambda output: int(output.request_id))
        return outputs