from typing import List, Optional, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from tqdm import tqdm

from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer
from cacheflow.utils import Counter


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the `AsyncLLMServer` class instead.
    NOTE: For the comprehensive list of arguments, see `ServerArgs`.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        seed: The seed to initialize the random number generator for sampling.
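
    Example (illustrative sketch; the model name is only a placeholder):

        llm = LLM(model="facebook/opt-125m")
        outputs = llm.generate(["Hello, my name is"])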
    """

    def __init__(
        self,
        model: str,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        seed: int = 0,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        server_args = ServerArgs(
            model=model,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            seed=seed,
            **kwargs,
        )
        self.llm_server = LLMServer.from_server_args(server_args)
        self.request_counter = Counter()

    def get_tokenizer(
        self,
    ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
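        """Returns the tokenizer used by the underlying LLMServer."""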
        return self.llm_server.tokenizer

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[SamplingParams] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        NOTE: This class automatically batches the given prompts, considering
        memory constraints. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: A list of prompts to generate completions for.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
            prompt_token_ids: A list of token IDs for the prompts. If None, we
                use the tokenizer to convert the prompts to token IDs.
            use_tqdm: Whether to use tqdm to display the progress bar.

        Returns:
            A list of `RequestOutput` objects containing the generated
            completions in the same order as the input prompts.
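
        Example (illustrative; `temperature` is assumed to be a supported
        SamplingParams field):

            params = SamplingParams(temperature=0.8)
            outputs = llm.generate(["Hello, my name is"], params)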
        """
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        if prompts is not None and prompt_token_ids is not None:
            if len(prompts) != len(prompt_token_ids):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        # Add requests to the server.
        if prompts is not None:
            num_requests = len(prompts)
        else:
            num_requests = len(prompt_token_ids)
        for i in range(num_requests):
            prompt = prompts[i] if prompts is not None else None
            if prompt_token_ids is None:
                token_ids = None
            else:
                token_ids = prompt_token_ids[i]
            self._add_request(prompt, sampling_params, token_ids)
        return self._run_server(use_tqdm)

    def _add_request(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
    ) -> None:
        request_id = str(next(self.request_counter))
        self.llm_server.add_request(request_id, prompt, sampling_params,
                                    prompt_token_ids)

    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_server.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests, desc="Processed prompts")
        # Run the server.
        outputs: List[RequestOutput] = []
        while self.llm_server.has_unfinished_requests():
            step_outputs = self.llm_server.step()
            for output in step_outputs:
                if output.finished():
                    outputs.append(output)
                    if use_tqdm:
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        return outputs