utils.py 12.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import json
import os
6
from dataclasses import dataclass
7
from typing import Optional, Union
8
9

import torch
10
from safetensors.torch import save_file
11

12
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
13
14
15


class DummyLoRAManager:
    """In-memory registry of LoRA layer weights keyed by module name.

    Besides plain set/get/reset access, the manager offers helpers that
    build ``LoRALayerWeights`` filled with random values (single or packed)
    and register them under the given module name.
    """

    def __init__(self, device: Union[torch.device, str] = "cuda:0"):
        """Create an empty registry.

        Args:
            device: Device on which the helper methods allocate the random
                LoRA tensors.
        """
        super().__init__()
        self._loras: dict[str, LoRALayerWeights] = {}
        self._device = device

    def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
        """Register ``lora`` under ``module_name``, replacing any previous entry."""
        self._loras[module_name] = lora

    def get_module_lora(self, module_name: str) -> LoRALayerWeights:
        """Return the LoRA registered under ``module_name``.

        Raises:
            KeyError: If nothing has been registered for ``module_name``.
        """
        return self._loras[module_name]

    def init_random_lora(
        self,
        module_name: str,
        weight: torch.Tensor,
        rank: int = 8,
        generate_embeddings_tensor: int = 0,
    ):
        """Build, register and return a random LoRA shaped after ``weight``.

        Args:
            module_name: Name under which the LoRA is registered.
            weight: Reference tensor; ``weight.shape[1]`` is used as the
                column count of ``lora_a``, ``weight.shape[0]`` as the row
                count of ``lora_b``, and its dtype is reused.
            rank: LoRA rank.
            generate_embeddings_tensor: When non-zero, a random
                ``embeddings_tensor`` of shape
                ``(5, generate_embeddings_tensor)`` is attached to the LoRA.

        Returns:
            The newly created ``LoRALayerWeights``.
        """
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=torch.rand(
                [rank, weight.shape[1]], dtype=weight.dtype, device=self._device
            ),
            lora_b=torch.rand(
                [weight.shape[0], rank], dtype=weight.dtype, device=self._device
            ),
        )
        if generate_embeddings_tensor:
            lora.embeddings_tensor = torch.rand(
                5,
                generate_embeddings_tensor,
                dtype=weight.dtype,
                device=self._device,
            )
        self.set_module_lora(module_name, lora)

        return lora

    def init_lora(
        self,
        module_name: str,
        input_dim: int,
        output_dim: int,
        rank=8,
        noop=False,
        embeddings_tensor=None,
    ):
        """Build, register and return a random LoRA from explicit dimensions.

        Args:
            module_name: Name under which the LoRA is registered.
            input_dim: Column count of ``lora_a``.
            output_dim: Row count of ``lora_b``.
            rank: LoRA rank.
            noop: Accepted for interface compatibility; currently not used
                by this method.
            embeddings_tensor: Forwarded unchanged to ``LoRALayerWeights``.

        Returns:
            The newly created ``LoRALayerWeights``.
        """
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            # Same layout as init_random_lora: A is [rank, in], B is [out, rank],
            # both allocated on the manager's configured device.
            lora_a=torch.rand([rank, input_dim], device=self._device),
            lora_b=torch.rand([output_dim, rank], device=self._device),
            embeddings_tensor=embeddings_tensor,
        )
        self.set_module_lora(module_name, lora)
        return lora

    def reset_lora(self):
        """Drop every registered LoRA."""
        self._loras = {}

    def init_packed_lora(
        self,
        module_name: str,
        input_dim: int,
        output_dims: list[int],
        noop_lora_index: Optional[list[int]] = None,
        rank: int = 8,
    ):
        """Build one LoRA per entry of ``output_dims`` and pack them together.

        Each sub-LoRA is created through ``init_lora`` (and therefore also
        registered) under the name ``f"{module_name}_000_{i}"``. The packed
        result from ``PackedLoRALayerWeights.pack`` is registered under
        ``module_name`` itself.

        Args:
            module_name: Name under which the packed LoRA is registered.
            input_dim: Input dimension shared by all sub-LoRAs.
            output_dims: Output dimension of each sub-LoRA, in order.
            noop_lora_index: Positions whose sub-LoRA is created with
                ``noop=True``; ``None`` means no such positions.
            rank: LoRA rank used for every sub-LoRA.

        Returns:
            The packed LoRA weights.
        """
        noop_lora_index_set = set(noop_lora_index or [])
        base_loras: list[LoRALayerWeights] = [
            self.init_lora(
                f"{module_name}_000_{i}",
                input_dim,
                out_dim,
                rank=rank,
                noop=i in noop_lora_index_set,
            )
            for i, out_dim in enumerate(output_dims)
        ]
        packed_lora = PackedLoRALayerWeights.pack(base_loras)
        self.set_module_lora(module_name, packed_lora)
        return packed_lora
102
103
104
105
106
107
108
109
110
111
112


def assert_close(a, b):
    """Assert that ``a`` and ``b`` match within a dtype-dependent tolerance.

    The tolerance (used for both ``rtol`` and ``atol``) is chosen from the
    dtype of ``a``: 6e-2 for float16/bfloat16 and 1e-2 for float32.

    Raises:
        KeyError: If ``a.dtype`` is none of the supported dtypes.
        AssertionError: If the tensors differ by more than the tolerance.
    """
    dtype = a.dtype
    if dtype in (torch.float16, torch.bfloat16):
        tolerance = 6e-2
    elif dtype == torch.float32:
        tolerance = 1e-2
    else:
        # Only the three dtypes above have a tolerance defined.
        raise KeyError(dtype)
    torch.testing.assert_close(a, b, rtol=tolerance, atol=tolerance)


113
114
115
@dataclass
class PunicaTensors:
    """Bundle of tensors returned by the data generators in this module."""

    # Input activations handed to the op under test.
    inputs_tensor: torch.Tensor
    # LoRA weights: one stacked tensor, or one tensor per slice.
    lora_weights: Union[torch.Tensor, list[torch.Tensor]]
    # Output buffer for the op under test.
    our_out_tensor: torch.Tensor
    # Output buffer for the reference computation.
    ref_out_tensor: torch.Tensor
    # Start offset of each sequence (cumulative sum of preceding lengths).
    b_seq_start_loc: torch.Tensor
    # LoRA index chosen for each batch entry.
    prompt_lora_mapping: torch.Tensor
    # Length of each sequence in the batch.
    seq_len_tensor: torch.Tensor
    # LoRA index for each individual token.
    token_lora_mapping: torch.Tensor

    def meta(self) -> tuple[int, int]:
        """
        Infer max_seq_length and token_nums from the tensors
        and return them.

        Returns:
            ``(max_seq_length, token_nums)``: the largest entry and the sum
            of ``seq_len_tensor``, both as Python scalars.
        """
        # Tensor.max() without a ``dim`` argument always yields a 0-dim
        # tensor, so no tuple handling is needed here.
        max_seq_length = self.seq_len_tensor.max().item()
        token_nums = self.seq_len_tensor.sum().item()
        return max_seq_length, token_nums


138
139
140
141
142
143
144
145
146
def generate_data(
    batches,
    hidden_size,
    lora_nums,
    max_rank,
    seq_length,
    dtype,
    op_type,
    device,
) -> PunicaTensors:
    """Create random tensors for a single-slice shrink or expand op.

    Args:
        batches: Number of sequences in the batch.
        hidden_size: Hidden dimension of the model side of the op.
        lora_nums: Number of LoRA weight sets to generate.
        max_rank: LoRA rank dimension.
        seq_length: Length given to every sequence.
        dtype: dtype of the inputs and LoRA weights.
        op_type: ``"shrink"`` selects the shrink layout; any other value
            selects the expand layout.
        device: Device the returned tensors are moved to.

    Returns:
        A ``PunicaTensors`` bundle with inputs, weights, output buffers and
        the sequence / LoRA mapping tensors.
    """
    # randint over [seq_length, seq_length + 1) makes every entry seq_length.
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    preceding_lengths = [0] + seq_len_tensor[:-1].tolist()
    b_seq_start_loc = torch.cumsum(
        torch.tensor(preceding_lengths, dtype=torch.long), dim=0
    ).to(device)
    total_tokens = int(seq_len_tensor.sum().item())

    if op_type == "shrink":
        inputs_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device)
        # col-major
        weight_shape = (lora_nums, max_rank, hidden_size)
        lora_weights = torch.rand(weight_shape, dtype=dtype).to(device)
        # shrink op needs atomic_add, so the output is initialized to 0
        ref_out_tensor = torch.zeros(
            (total_tokens, max_rank), dtype=dtype, device=inputs_tensor.device
        )
        # NOTE: the shrink kernel uses torch.float32 as its output type
        our_out_tensor = torch.zeros((total_tokens, max_rank), dtype=torch.float32)
        our_out_tensor = our_out_tensor.to(device)
    else:
        inputs_tensor = torch.rand((total_tokens, max_rank), dtype=dtype).to(device)
        # col-major
        weight_shape = (lora_nums, hidden_size, max_rank)
        lora_weights = torch.rand(weight_shape, dtype=dtype).to(device)
        # expand op needs to complete y+=a@lora_b, so the output is
        # initialized randomly
        ref_out_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device)
        # Ensure the same input.
        our_out_tensor = ref_out_tensor.clone()

    # NOTE(review): torch.randint's upper bound is exclusive, so for
    # lora_nums > 1 the index lora_nums - 1 is never drawn -- confirm that
    # this is intended.
    index_upper_bound = lora_nums - 1 if lora_nums > 1 else 1
    lora_indices_tensor = torch.randint(0, index_upper_bound, (batches,)).to(device)

    # Expand the per-sequence LoRA index into a per-token mapping.
    indices = torch.zeros(total_tokens, dtype=torch.long).to(device)
    offset = 0
    for length, lora_index in zip(seq_len_tensor.tolist(), lora_indices_tensor):
        indices[offset : offset + length].copy_(lora_index)
        offset += length

    return PunicaTensors(
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    )


209
210
211
212
213
214
215
216
217
def generate_data_for_expand_nslices(
    batches,
    hidden_size,
    lora_nums,
    max_rank,
    seq_length,
    dtype,
    nslices,
    device,
) -> PunicaTensors:
    """Create random tensors for an expand op split into ``nslices`` slices.

    Args:
        batches: Number of sequences in the batch.
        hidden_size: Hidden dimension of each slice.
        lora_nums: Number of LoRA weight sets per slice.
        max_rank: LoRA rank dimension.
        seq_length: Length given to every sequence.
        dtype: dtype of the inputs, weights and output buffers.
        nslices: Number of weight tensors to generate; the output buffers
            have ``hidden_size * nslices`` columns.
        device: Device the returned tensors are moved to.

    Returns:
        A ``PunicaTensors`` bundle whose ``lora_weights`` is a list with one
        tensor per slice.
    """
    # randint over [seq_length, seq_length + 1) makes every entry seq_length.
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    preceding_lengths = [0] + seq_len_tensor[:-1].tolist()
    b_seq_start_loc = torch.cumsum(
        torch.tensor(preceding_lengths, dtype=torch.long), dim=0
    ).to(device)
    total_tokens = int(seq_len_tensor.sum().item())

    inputs_tensor = torch.rand((total_tokens, max_rank), dtype=dtype).to(device)

    # col-major
    weight_shape = (lora_nums, hidden_size, max_rank)
    lora_weights_lst = [
        torch.rand(weight_shape, dtype=dtype).to(device) for _ in range(nslices)
    ]

    # expand op needs to complete y+=a@lora_b, so the output is
    # initialized randomly
    ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices), dtype=dtype)
    ref_out_tensor = ref_out_tensor.to(device)
    # Ensure the same input.
    our_out_tensor = ref_out_tensor.clone()

    # NOTE(review): torch.randint's upper bound is exclusive, so for
    # lora_nums > 1 the index lora_nums - 1 is never drawn -- confirm that
    # this is intended.
    index_upper_bound = lora_nums - 1 if lora_nums > 1 else 1
    lora_indices_tensor = torch.randint(0, index_upper_bound, (batches,))

    # Expand the per-sequence LoRA index into a per-token mapping.
    indices = torch.zeros(total_tokens, dtype=torch.long).to(device)
    offset = 0
    for length, lora_index in zip(
        seq_len_tensor.tolist(), lora_indices_tensor.tolist()
    ):
        indices[offset : offset + length] = lora_index
        offset += length

    lora_indices_tensor = lora_indices_tensor.to(device)
    return PunicaTensors(
        inputs_tensor,
        lora_weights_lst,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    )


def generate_data_for_nslices(
    batches,
    hidden_size,
    lora_nums,
    max_rank,
    seq_length,
    nslices,
    dtype,
    op_type,
    device,
) -> PunicaTensors:
    """Create random tensors for a multi-slice shrink or expand op.

    Args:
        batches: Number of sequences in the batch.
        hidden_size: Hidden dimension of each slice.
        lora_nums: Number of LoRA weight sets per slice.
        max_rank: LoRA rank dimension.
        seq_length: Length given to every sequence.
        nslices: Number of weight tensors (slices) to generate.
        dtype: dtype of the inputs and LoRA weights.
        op_type: ``"shrink"`` selects the shrink layout; any other value
            selects the expand layout.
        device: Device the returned tensors are moved to.

    Returns:
        A ``PunicaTensors`` bundle whose ``lora_weights`` is a list with one
        tensor per slice; ``ref_out_tensor`` is a clone of
        ``our_out_tensor``.
    """
    # randint over [seq_length, seq_length + 1) makes every entry seq_length.
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    total_tokens = int(seq_len_tensor.sum().item())

    if op_type == "shrink":
        inputs_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device)
        # col-major
        weight_shape = (lora_nums, max_rank, hidden_size)
        lora_weights_lst = [
            torch.rand(weight_shape, dtype=dtype).to(device) for _ in range(nslices)
        ]
        # NOTE: the shrink kernel uses torch.float32 as its output type.
        # shrink op needs atomic_add, so the output is initialized to 0
        our_out_tensor = torch.zeros(
            (nslices, total_tokens, max_rank),
            dtype=torch.float32,
        ).to(device)
    else:
        inputs_tensor = torch.rand(
            (nslices, total_tokens, max_rank),
            dtype=dtype,
        ).to(device)
        # col-major
        weight_shape = (lora_nums, hidden_size, max_rank)
        lora_weights_lst = [
            torch.rand(weight_shape, dtype=dtype).to(device) for _ in range(nslices)
        ]
        # expand op needs to complete y+=a@lora_b, so the output is
        # initialized randomly
        our_out_tensor = torch.rand(
            (total_tokens, hidden_size * nslices), dtype=dtype
        ).to(device)

    # Ensure the same input.
    ref_out_tensor = our_out_tensor.clone()

    # NOTE(review): torch.randint's upper bound is exclusive, so for
    # lora_nums > 1 the index lora_nums - 1 is never drawn -- confirm that
    # this is intended.
    lora_indices_tensor = torch.randint(
        0, lora_nums - 1 if lora_nums > 1 else 1, (batches,)
    )

    # Expand the per-sequence LoRA index into a per-token mapping.
    indices = torch.zeros(total_tokens, dtype=torch.long).to(device)
    current_offset = 0
    for length, lora_index in zip(
        seq_len_tensor.tolist(), lora_indices_tensor.tolist()
    ):
        indices[current_offset : current_offset + length] = lora_index
        current_offset += length

    lora_indices_tensor = lora_indices_tensor.to(device)
    return PunicaTensors(
        inputs_tensor,
        lora_weights_lst,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    )
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391


def create_peft_lora(
    model: torch.nn.Module,
    save_dir: str,
    target_modules: list[str],
    rank: int = 8,
    alpha: int = 16,
    dropout: float = 0.1,
    lora_dtype: torch.dtype = torch.float16,
) -> dict[str, torch.Tensor]:
    """Write a PEFT-style LoRA adapter for ``model`` into ``save_dir``.

    For every entry of ``target_modules`` a ``lora_A`` (Kaiming-uniform) and
    a ``lora_B`` (all zeros) tensor is created. The tensors are saved to
    ``adapter_model.safetensors`` and the accompanying configuration to
    ``adapter_config.json``.

    Args:
        model: Module whose sub-modules are resolved by dotted attribute path.
        save_dir: Directory receiving the two adapter files; it is created
            if it does not exist yet.
        target_modules: Dotted attribute paths of the modules to adapt.
        rank: LoRA rank, stored as ``r`` in the config.
        alpha: Stored as ``lora_alpha`` in the config.
        dropout: Stored as ``lora_dropout`` in the config.
        lora_dtype: dtype of the generated LoRA tensors.

    Returns:
        Mapping from PEFT-style tensor names to the generated tensors.

    Raises:
        ValueError: If a target module exposes neither
            ``input_size``/``output_size`` nor
            ``embedding_dim``/``num_embeddings``.
        AttributeError: If a dotted path cannot be resolved on ``model``.
    """
    lora_weights: dict[str, torch.Tensor] = {}
    adapter_config = {
        "peft_type": "LORA",
        "auto_mapping": None,
        "base_model_name_or_path": "dummy_model",
        "revision": None,
        "task_type": "CAUSAL_LM",
        "inference_mode": False,
        "r": rank,
        "lora_alpha": alpha,
        "lora_dropout": dropout,
        "fan_in_fan_out": False,
        "bias": "none",
        "modules_to_save": None,
        "init_lora_weights": True,
        "layers_to_transform": None,
        "layers_pattern": None,
        "target_modules": target_modules,
        "exclude_modules": None,
        "use_rslora": False,
        "use_dora": False,
        "loftq_config": None,
    }

    for module_name in target_modules:
        # Walk the dotted path one attribute at a time.
        module = model
        for attr in module_name.split("."):
            module = getattr(module, attr)

        if hasattr(module, "input_size") and hasattr(module, "output_size"):
            in_features = module.input_size
            out_features = module.output_size
        elif hasattr(module, "embedding_dim") and hasattr(module, "num_embeddings"):
            # ParallelLMHead
            in_features = module.embedding_dim
            out_features = module.num_embeddings
        else:
            raise ValueError(f"Unable to determine dimensions for module {module_name}")

        # The randn values are overwritten in place by the Kaiming init
        # below; the call is kept so the RNG stream stays unchanged.
        lora_A = torch.randn(rank, in_features, dtype=lora_dtype)
        torch.nn.init.kaiming_uniform_(lora_A, a=5**0.5)

        lora_B = torch.zeros(out_features, rank, dtype=lora_dtype)

        # PEFT style
        lora_weights[f"base_model.model.{module_name}.lora_A.weight"] = lora_A
        lora_weights[f"base_model.model.{module_name}.lora_B.weight"] = lora_B

    # Create the target directory on demand. An empty save_dir means the
    # current working directory, which needs no creation.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    config_path = os.path.join(save_dir, "adapter_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(adapter_config, f, indent=2, ensure_ascii=False)

    weights_path = os.path.join(save_dir, "adapter_model.safetensors")
    save_file(lora_weights, weights_path)

    return lora_weights