utils.py 12.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import json
import os
6
from dataclasses import dataclass
7
8

import torch
9
from safetensors.torch import save_file
10

11
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
12
13
14


class DummyLoRAManager:
    """Minimal in-memory registry of LoRA weights, keyed by module name.

    Used by tests to fabricate random ``LoRALayerWeights`` /
    ``PackedLoRALayerWeights`` objects without loading a real adapter.
    """

    def __init__(self, device: torch.device | str = "cuda:0"):
        """Create an empty registry.

        Args:
            device: Device on which the random LoRA tensors are allocated.
        """
        super().__init__()
        self._loras: dict[str, LoRALayerWeights] = {}
        self._device = device

    def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
        """Register ``lora`` under ``module_name``, replacing any previous entry."""
        self._loras[module_name] = lora

    def get_module_lora(self, module_name: str) -> LoRALayerWeights:
        """Return the LoRA registered for ``module_name``.

        Raises:
            KeyError: If nothing was registered under ``module_name``.
        """
        return self._loras[module_name]

    def init_random_lora(
        self,
        module_name: str,
        weight: torch.Tensor,
        rank: int = 8,
    ):
        """Create, register and return a random LoRA sized after ``weight``.

        ``lora_a`` has shape ``[rank, weight.shape[1]]`` and ``lora_b`` has
        shape ``[weight.shape[0], rank]``; both use ``weight.dtype`` and live
        on the manager's device.
        """
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=torch.rand(
                [rank, weight.shape[1]], dtype=weight.dtype, device=self._device
            ),
            lora_b=torch.rand(
                [weight.shape[0], rank], dtype=weight.dtype, device=self._device
            ),
        )
        self.set_module_lora(module_name, lora)

        return lora

    def init_lora(
        self,
        module_name: str,
        input_dim: int,
        output_dim: int,
        rank=8,
        noop=False,
        embeddings_tensor=None,
    ):
        """Create, register and return a random LoRA with explicit dimensions.

        Args:
            module_name: Key under which the LoRA is registered.
            input_dim: Second dimension of ``lora_a``.
            output_dim: First dimension of ``lora_b``.
            rank: LoRA rank shared by ``lora_a`` and ``lora_b``.
            noop: Accepted for call-site compatibility; not read by this
                method.
            embeddings_tensor: Forwarded unchanged to ``LoRALayerWeights``.
        """
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            # Same layout as init_random_lora: A is [rank, in], B is
            # [out, rank]. Allocate on the manager's device rather than a
            # hard-coded "cuda".
            lora_a=torch.rand([rank, input_dim], device=self._device),
            lora_b=torch.rand([output_dim, rank], device=self._device),
            embeddings_tensor=embeddings_tensor,
        )
        self.set_module_lora(module_name, lora)
        return lora

    def reset_lora(self):
        """Drop every registered LoRA."""
        self._loras = {}

    def init_packed_lora(
        self,
        module_name: str,
        input_dim: int,
        output_dims: list[int],
        noop_lora_index: list[int] | None = None,
        rank: int = 8,
    ):
        """Create one LoRA per entry of ``output_dims`` and pack them together.

        Each sub-LoRA is built via ``init_lora`` (and is therefore also
        registered on its own under ``f"{module_name}_000_{i}"``). The packed
        result is registered under ``module_name`` and returned.

        Args:
            module_name: Key for the packed LoRA and prefix for sub-LoRA keys.
            input_dim: Input dimension shared by all sub-LoRAs.
            output_dims: Output dimension of each sub-LoRA, in packing order.
            noop_lora_index: Positions whose sub-LoRA is created with
                ``noop=True``; ``None`` means no such positions.
            rank: LoRA rank used for every sub-LoRA.
        """
        noop_lora_index_set = set(noop_lora_index or [])
        base_loras: list[LoRALayerWeights] = [
            self.init_lora(
                module_name + "_000_" + str(i),
                input_dim,
                out_dim,
                rank=rank,
                noop=i in noop_lora_index_set,
            )
            for i, out_dim in enumerate(output_dims)
        ]
        packed_lora = PackedLoRALayerWeights.pack(base_loras)
        self.set_module_lora(module_name, packed_lora)
        return packed_lora
93
94
95
96
97
98
99
100
101
102
103


def assert_close(a, b):
    """Assert that ``a`` and ``b`` are close, with tolerances chosen by dtype.

    float16 and bfloat16 use ``rtol = atol = 6e-2``; float32 uses
    ``rtol = atol = 1e-2``. Any other dtype falls back to the default
    tolerances of ``torch.testing.assert_close`` instead of failing with a
    bare ``KeyError`` from the lookup table.

    Args:
        a: Tensor under test; its dtype selects the tolerances.
        b: Reference tensor.

    Raises:
        AssertionError: If the tensors differ beyond the tolerances.
    """
    tolerances = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }
    # (None, None) makes torch pick its own dtype-based defaults.
    rtol, atol = tolerances.get(a.dtype, (None, None))
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


104
105
106
@dataclass
class PunicaTensors:
    inputs_tensor: torch.Tensor
107
    lora_weights: torch.Tensor | list[torch.Tensor]
108
109
110
111
112
113
114
    our_out_tensor: torch.Tensor
    ref_out_tensor: torch.Tensor
    b_seq_start_loc: torch.Tensor
    prompt_lora_mapping: torch.Tensor
    seq_len_tensor: torch.Tensor
    token_lora_mapping: torch.Tensor

115
    def meta(self) -> tuple[int, int]:
116
117
118
119
120
121
122
123
124
125
126
127
128
        """
        Infer max_seq_length and token_nums from the tensors
        and return them.
        """
        max_seq_length = self.seq_len_tensor.max()
        token_nums = self.seq_len_tensor.sum().item()
        if isinstance(max_seq_length, tuple):
            max_seq_length = max_seq_length[0].item()
        else:
            max_seq_length = max_seq_length.item()
        return max_seq_length, token_nums


129
130
131
132
133
134
135
136
137
def generate_data(
    batches,
    hidden_size,
    lora_nums,
    max_rank,
    seq_length,
    dtype,
    op_type,
    device,
) -> PunicaTensors:
    """Build random inputs, weights and output buffers for a single-slice op.

    Args:
        batches: Number of sequences in the batch.
        hidden_size: Hidden dimension of the base layer.
        lora_nums: Number of LoRA adapters in the stacked weight tensor.
        max_rank: LoRA rank dimension of the generated tensors.
        seq_length: Token count of every sequence.
        dtype: dtype of inputs and weights.
        op_type: ``"shrink"`` selects the shrink layout; any other value
            selects the expand layout.
        device: Device the returned tensors are moved to.

    Returns:
        A ``PunicaTensors`` holding the generated data.
    """
    # randint's upper bound is exclusive, so every entry equals seq_length.
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    # Plain Python int, so it can be used directly as a shape dimension.
    total_tokens = seq_len_tensor.sum().item()

    if op_type == "shrink":
        inputs_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device)
        lora_weights = torch.rand(
            (lora_nums, max_rank, hidden_size),  # col-major
            dtype=dtype,
        ).to(device)
        # shrink op needs atomic_add, so the output is initialized to 0
        ref_out_tensor = torch.zeros(
            (total_tokens, max_rank), dtype=dtype, device=inputs_tensor.device
        )
        # NOTE: the shrink kernel uses torch.float32 as its output type
        our_out_tensor = torch.zeros((total_tokens, max_rank), dtype=torch.float32).to(
            device
        )
    else:
        inputs_tensor = torch.rand(
            (total_tokens, max_rank),
            dtype=dtype,
        ).to(device)
        lora_weights = torch.rand(
            (lora_nums, hidden_size, max_rank),  # col-major
            dtype=dtype,
        ).to(device)
        # expand op needs to complete y+=a@lora_b, so the output is
        # initialized randomly
        ref_out_tensor = torch.rand(
            (total_tokens, hidden_size),
            dtype=dtype,
        ).to(device)
        # Ensure the same input.
        our_out_tensor = ref_out_tensor.clone()

    # NOTE(review): the upper bound is exclusive, so index ``lora_nums - 1``
    # is never sampled when lora_nums > 1 - confirm this is intended.
    lora_indices_tensor = torch.randint(
        0, lora_nums - 1 if lora_nums > 1 else 1, (batches,)
    ).to(device)

    # Every token inherits the LoRA index of the sequence it belongs to.
    indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)

    return PunicaTensors(
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    )


200
201
202
203
204
205
206
207
208
def generate_data_for_expand_nslices(
    batches,
    hidden_size,
    lora_nums,
    max_rank,
    seq_length,
    dtype,
    nslices,
    device,
) -> PunicaTensors:
    """Build random data for an expand op whose output is split into slices.

    Args:
        batches: Number of sequences in the batch.
        hidden_size: Output width of a single slice.
        lora_nums: Number of LoRA adapters in each weight tensor.
        max_rank: LoRA rank dimension of the generated tensors.
        seq_length: Token count of every sequence.
        dtype: dtype of inputs, weights and outputs.
        nslices: Number of slices; one weight tensor is generated per slice.
        device: Device the returned tensors are moved to.

    Returns:
        A ``PunicaTensors`` whose ``lora_weights`` is a list of ``nslices``
        tensors and whose outputs have width ``hidden_size * nslices``.
    """
    # randint's upper bound is exclusive, so every entry equals seq_length.
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    # Plain Python int, so it can be used directly as a shape dimension.
    total_tokens = seq_len_tensor.sum().item()

    inputs_tensor = torch.rand(
        (total_tokens, max_rank),
        dtype=dtype,
    ).to(device)
    lora_weights_lst = [
        torch.rand(
            (lora_nums, hidden_size, max_rank),  # col-major
            dtype=dtype,
        ).to(device)
        for _ in range(nslices)
    ]

    # expand op needs to complete y+=a@lora_b, so the output is
    # initialized randomly
    ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices), dtype=dtype).to(
        device
    )
    # Ensure the same input.
    our_out_tensor = ref_out_tensor.clone()

    # NOTE(review): the upper bound is exclusive, so index ``lora_nums - 1``
    # is never sampled when lora_nums > 1 - confirm this is intended.
    lora_indices_tensor = torch.randint(
        0, lora_nums - 1 if lora_nums > 1 else 1, (batches,)
    ).to(device)

    # Every token inherits the LoRA index of the sequence it belongs to.
    indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)

    return PunicaTensors(
        inputs_tensor,
        lora_weights_lst,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    )


def generate_data_for_nslices(
    batches,
    hidden_size,
    lora_nums,
    max_rank,
    seq_length,
    nslices,
    dtype,
    op_type,
    device,
) -> PunicaTensors:
    """Build random data for a multi-slice shrink or expand op.

    Args:
        batches: Number of sequences in the batch.
        hidden_size: Hidden dimension of a single slice.
        lora_nums: Number of LoRA adapters in each weight tensor.
        max_rank: LoRA rank dimension of the generated tensors.
        seq_length: Token count of every sequence.
        nslices: Number of slices; one weight tensor is generated per slice.
        dtype: dtype of inputs and weights.
        op_type: ``"shrink"`` selects the shrink layout; any other value
            selects the expand layout.
        device: Device the returned tensors are moved to.

    Returns:
        A ``PunicaTensors`` whose ``lora_weights`` is a list of ``nslices``
        tensors. ``ref_out_tensor`` is a clone of ``our_out_tensor``.
    """
    # randint's upper bound is exclusive, so every entry equals seq_length.
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    # Plain Python int, so it can be used directly as a shape dimension.
    total_tokens = seq_len_tensor.sum().item()

    if op_type == "shrink":
        inputs_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device)
        lora_weights_lst = [
            torch.rand(
                (lora_nums, max_rank, hidden_size),  # col-major
                dtype=dtype,
            ).to(device)
            for _ in range(nslices)
        ]
        # NOTE: the shrink kernel uses torch.float32 as its output type
        # shrink op needs atomic_add, so the output is initialized to 0
        our_out_tensor = torch.zeros(
            (nslices, total_tokens, max_rank),
            dtype=torch.float32,
        ).to(device)
    else:
        inputs_tensor = torch.rand(
            (nslices, total_tokens, max_rank),
            dtype=dtype,
        ).to(device)
        lora_weights_lst = [
            torch.rand(
                (lora_nums, hidden_size, max_rank),  # col-major
                dtype=dtype,
            ).to(device)
            for _ in range(nslices)
        ]
        # expand op needs to complete y+=a@lora_b, so the output is
        # initialized randomly
        our_out_tensor = torch.rand(
            (total_tokens, hidden_size * nslices), dtype=dtype
        ).to(device)

    # Ensure the same input.
    ref_out_tensor = our_out_tensor.clone()

    # NOTE(review): the upper bound is exclusive, so index ``lora_nums - 1``
    # is never sampled when lora_nums > 1 - confirm this is intended.
    lora_indices_tensor = torch.randint(
        0, lora_nums - 1 if lora_nums > 1 else 1, (batches,)
    ).to(device)

    # Every token inherits the LoRA index of the sequence it belongs to.
    indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)

    return PunicaTensors(
        inputs_tensor,
        lora_weights_lst,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    )
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382


def create_peft_lora(
    model: torch.nn.Module,
    save_dir: str,
    target_modules: list[str],
    rank: int = 8,
    alpha: int = 16,
    dropout: float = 0.1,
    lora_dtype: torch.dtype = torch.float16,
) -> dict[str, torch.Tensor]:
    """Write a PEFT-style LoRA adapter for ``model`` into ``save_dir``.

    For every entry of ``target_modules`` a ``lora_A`` (Kaiming-uniform
    initialized) and a ``lora_B`` (all zeros) tensor are created. The weights
    are saved to ``adapter_model.safetensors`` and the matching configuration
    to ``adapter_config.json``.

    Args:
        model: Model whose sub-modules determine the LoRA dimensions.
        save_dir: Output directory; created if it does not exist.
        target_modules: Dotted attribute paths, relative to ``model``, of the
            modules to adapt.
        rank: LoRA rank (``r`` in the config).
        alpha: Value stored as ``lora_alpha`` in the config.
        dropout: Value stored as ``lora_dropout`` in the config.
        lora_dtype: dtype of the generated LoRA tensors.

    Returns:
        The mapping of PEFT weight names to the generated tensors.

    Raises:
        AttributeError: If a dotted path cannot be resolved on ``model``.
        ValueError: If a target module exposes neither
            ``input_size``/``output_size`` nor
            ``embedding_dim``/``num_embeddings``.
    """
    lora_weights = {}
    adapter_config = {
        "peft_type": "LORA",
        "auto_mapping": None,
        "base_model_name_or_path": "dummy_model",
        "revision": None,
        "task_type": "CAUSAL_LM",
        "inference_mode": False,
        "r": rank,
        "lora_alpha": alpha,
        "lora_dropout": dropout,
        "fan_in_fan_out": False,
        "bias": "none",
        "modules_to_save": None,
        "init_lora_weights": True,
        "layers_to_transform": None,
        "layers_pattern": None,
        "target_modules": target_modules,
        "exclude_modules": None,
        "use_rslora": False,
        "use_dora": False,
        "loftq_config": None,
    }

    for module_name in target_modules:
        # Walk the dotted path down from the root model.
        module = model
        for attr in module_name.split("."):
            module = getattr(module, attr)

        if hasattr(module, "input_size") and hasattr(module, "output_size"):
            in_features = module.input_size
            out_features = module.output_size
        elif hasattr(module, "embedding_dim") and hasattr(module, "num_embeddings"):
            # ParallelLMHead
            in_features = module.embedding_dim
            out_features = module.num_embeddings
        else:
            raise ValueError(f"Unable to determine dimensions for module {module_name}")

        lora_A = torch.randn(rank, in_features, dtype=lora_dtype)
        torch.nn.init.kaiming_uniform_(lora_A, a=5**0.5)
        lora_B = torch.zeros(out_features, rank, dtype=lora_dtype)

        # PEFT style
        lora_weights[f"base_model.model.{module_name}.lora_A.weight"] = lora_A
        lora_weights[f"base_model.model.{module_name}.lora_B.weight"] = lora_B

    # Only touch the filesystem once every module has been validated.
    os.makedirs(save_dir, exist_ok=True)

    config_path = os.path.join(save_dir, "adapter_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(adapter_config, f, indent=2, ensure_ascii=False)

    weights_path = os.path.join(save_dir, "adapter_model.safetensors")
    save_file(lora_weights, weights_path)

    return lora_weights