llm.py

from typing import List, Optional, Union

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer
from cacheflow.utils import Counter


class LLM:
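    """An LLM for generating texts from given prompts and sampling parameters.

    This class wraps an LLMServer: generate() enqueues one request per prompt
    and then steps the server until every request has finished, so the call
    blocks until all outputs are available.

    Args:
        model: Name or path of the Hugging Face model to use.
        tensor_parallel_size: Number of GPUs to use for distributed execution
            with tensor parallelism.
        dtype: Data type for the model weights and activations.
        seed: Random seed used for sampling.
        **kwargs: Additional keyword arguments forwarded to ServerArgs.
            Unless overridden, "disable_log_stats" is set to True.

    Example (a minimal sketch; "facebook/opt-125m" is an assumed checkpoint
    name, not something this file prescribes):

        llm = LLM(model="facebook/opt-125m")
        outputs = llm.generate("Hello, my name is")
    """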

    def __init__(
        self,
        model: str,
        tensor_parallel_size: int = 1,
        dtype: str = "default",
        seed: int = 0,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        server_args = ServerArgs(
            model=model,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            seed=seed,
            **kwargs,
        )
        self.llm_server = LLMServer.from_server_args(server_args)
        self.request_counter = Counter()

    def get_tokenizer(
        self,
    ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
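        """Returns the tokenizer used by the underlying LLMServer."""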
        return self.llm_server.tokenizer

    def generate(
        self,
        prompts: Union[str, List[str]],
        sampling_params: Optional[SamplingParams] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
    ) -> List[RequestOutput]:
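        """Generates completions for the input prompts.

        NOTE: This call blocks until all of the given prompts have finished
        generating.

        Args:
            prompts: A single prompt string or a list of prompt strings.
            sampling_params: Sampling parameters shared by all prompts. If
                None, the default SamplingParams() is used.
            prompt_token_ids: Optional pre-tokenized prompts; when given, the
                i-th entry is used as the token ids of the i-th prompt.
            use_tqdm: Whether to show a tqdm progress bar.

        Returns:
            A list of RequestOutput objects, collected in the order the
            requests finish (which may differ from the input order).
        """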
        if isinstance(prompts, str):
            prompts = [prompts]
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()
        # Add requests to the server.
        for i in range(len(prompts)):
            prompt = prompts[i]
            if prompt_token_ids is None:
                token_ids = None
            else:
                token_ids = prompt_token_ids[i]
            self._add_request(prompt, sampling_params, token_ids)
        return self._run_server(use_tqdm)

    def _add_request(
        self,
        prompt: str,
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
    ) -> None:
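        """Assigns a new request id and enqueues the request on the server."""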
        request_id = str(next(self.request_counter))
        self.llm_server.add_request(request_id, prompt, sampling_params,
                                    prompt_token_ids)

    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
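        """Steps the server until all unfinished requests have completed."""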
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_server.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests, desc="Processed prompts")
        # Run the server.
        outputs: List[RequestOutput] = []
        while self.llm_server.has_unfinished_requests():
            step_outputs = self.llm_server.step()
            for output in step_outputs:
                if output.finished():
                    outputs.append(output)
                    if use_tqdm:
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        return outputs