"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Union

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from sglang.srt.server import Runtime

DEFAULT_PROMPTS = [
    # the output of gemma-2-2b from SRT is unstable on the commented prompt
    # "The capital of France is",
    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
    "The capital of the United Kindom is",
    "Today is a sunny day and I like",
    "AI is a field of computer science focused on",
]

dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
    long_prompt = f.read()
DEFAULT_PROMPTS.append(long_prompt)

NUM_TOP_LOGPROBS = 5


def get_dtype_str(torch_dtype):
    if torch_dtype is torch.float16:
        return "float16"
    else:
        raise NotImplementedError()


@dataclass
class ModelOutput:
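    """Container for the outputs returned by HFRunner and SRTRunner."""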
    output_strs: List[str] = None
    output_ids: List[int] = None
    top_input_logprobs: List[torch.Tensor] = None
    top_output_logprobs: List[torch.Tensor] = None
    embed_logits: List[torch.Tensor] = None


class HFRunner:
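    """Reference runner that serves a HuggingFace model from a child process.

    The model is loaded and executed in a separate process; the parent sends
    (prompts, max_new_tokens) requests through in_queue and reads ModelOutput
    results from out_queue.
    """
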
    def __init__(
        self,
        model_path,
        torch_dtype,
        is_generation,
    ):
        self.is_generation = is_generation

        self.in_queue = mp.Queue()
        self.out_queue = mp.Queue()

        self.model_proc = mp.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
            ),
        )
        self.model_proc.start()

    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
        )

        if self.is_generation:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=False,
                low_cpu_mem_usage=True,
            ).cuda()
        else:
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_path,
                model_kwargs={"torch_dtype": torch_dtype},
            )
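
        # Serve requests from the parent process: each message on in_queue is
        # a (prompts, max_new_tokens) tuple, and the result is sent back on
        # out_queue as a ModelOutput.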
        while True:
            prompts, max_new_tokens = in_queue.get()
            if prompts is not None:
                if self.is_generation:
                    output_strs = []
                    prefill_logprobs = []
                    for p in prompts:
                        if isinstance(p, str):
                            input_ids = self.tokenizer.encode(
                                p, return_tensors="pt"
                            ).cuda()
                        else:
                            input_ids = torch.tensor([p], device="cuda")

                        output_ids = self.model.generate(
                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
                        )
                        output_strs.append(
                            self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
                        )
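
                        # Run a prefill forward pass over the prompt and keep
                        # the top NUM_TOP_LOGPROBS logprobs at each position.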
                        logits = self.model.forward(input_ids).logits[0]
                        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                        logprobs, top_indices = torch.topk(
                            logprobs, k=NUM_TOP_LOGPROBS, dim=-1
                        )
                        # print("index", top_indices)
                        prefill_logprobs.append(logprobs.tolist())
                        del logits
                        del logprobs

                    out_queue.put(
                        ModelOutput(
                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
                        )
                    )

                else:
                    logits = self.model.encode(prompts).tolist()

                    out_queue.put(ModelOutput(embed_logits=logits))

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
    ):
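        # Blocking round trip: send the request to the model process and wait
        # for its ModelOutput reply.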
        self.in_queue.put((prompts, max_new_tokens))
        return self.out_queue.get()

    def terminate(self):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None


class SRTRunner:
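    """Runner that exposes the same forward() interface as HFRunner, backed by
    the SGLang Runtime."""
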
    def __init__(
        self,
        model_path,
        torch_dtype,
        is_generation,
        tp_size=1,
        port=5157,
    ):
        self.is_generation = is_generation
        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
            port=port,
            mem_fraction_static=0.7,
            trust_remote_code=False,
            is_embedding=not self.is_generation,
        )

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
    ):
        if self.is_generation:
            # the return value contains logprobs from prefill
            output_strs = []
            top_input_logprobs = []
            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
            for prompt in prompts:
                response = self.runtime.generate(
                    prompt,
                    sampling_params=sampling_params,
                    return_logprob=True,
                    logprob_start_len=0,
                    top_logprobs_num=NUM_TOP_LOGPROBS,
                )
                response = json.loads(response)
                output_strs.append(response["text"])
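                # Rebuild prefill logprobs in the same layout as HFRunner:
                # skip the first entry of input_top_logprobs (the first prompt
                # token has no preceding context) and append the top logprobs
                # of the first generated token.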
                top_input_logprobs.append(
                    [
                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                        for x in response["meta_info"]["input_top_logprobs"][1:]
                    ]
                    + [
                        [
                            tup[0]
                            for tup in response["meta_info"]["output_top_logprobs"][0][
                                :NUM_TOP_LOGPROBS
                            ]
                        ]
                    ]
                )

            return ModelOutput(
                output_strs=output_strs, top_input_logprobs=top_input_logprobs
            )
        else:
            response = self.runtime.encode(prompts)
            response = json.loads(response)
            logits = [x["embedding"] for x in response]
            return ModelOutput(embed_logits=logits)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.runtime.shutdown()
        del self.runtime
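

# Minimal usage sketch (illustrative; the checkpoint path below is a
# placeholder, not something this module depends on): run the same prompt
# through both runners with greedy decoding and compare the generated text.
if __name__ == "__main__":
    example_model_path = "meta-llama/Llama-2-7b-hf"  # hypothetical checkpoint
    example_prompts = ["Today is a sunny day and I like"]

    with HFRunner(example_model_path, torch.float16, is_generation=True) as hf_runner:
        hf_out = hf_runner.forward(example_prompts, max_new_tokens=8)

    with SRTRunner(example_model_path, torch.float16, is_generation=True) as srt_runner:
        srt_out = srt_runner.forward(example_prompts, max_new_tokens=8)

    print("HF :", hf_out.output_strs)
    print("SRT:", srt_out.output_strs)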