utils.py 12.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import json
import os
6
from dataclasses import dataclass
7
8

import torch
9
from safetensors.torch import save_file
10

11
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
12
13
14
from vllm.platforms import current_platform

# Device type reported by the current vLLM platform; DummyLoRAManager uses it
# to build its default device string.
DEVICE_TYPE = current_platform.device_type


class DummyLoRAManager:
    """Keep randomly initialised dummy LoRA weights for tests, keyed by module name.

    All tensors are created with ``torch.rand`` on the device given to the
    constructor and stored in an in-memory dict.
    """

    def __init__(self, device: torch.device | str = f"{DEVICE_TYPE}:0"):
        super().__init__()
        # module name -> LoRA weights (packed LoRAs are stored here as well).
        self._loras: dict[str, LoRALayerWeights] = {}
        # Device on which every tensor created by this manager is allocated.
        self._device = device

    def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
        """Register ``lora`` under ``module_name``, replacing any previous entry."""
        self._loras[module_name] = lora

    def get_module_lora(self, module_name: str) -> LoRALayerWeights:
        """Return the LoRA registered under ``module_name``.

        Raises:
            KeyError: if nothing has been registered under ``module_name``.
        """
        return self._loras[module_name]

    def init_random_lora(
        self,
        module_name: str,
        weight: torch.Tensor,
        rank: int = 8,
    ):
        """Create, register and return a random LoRA shaped to match ``weight``.

        ``lora_a`` has shape ``(rank, weight.shape[1])`` and ``lora_b`` has shape
        ``(weight.shape[0], rank)``; both use ``weight.dtype``.
        """
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=torch.rand(
                [rank, weight.shape[1]], dtype=weight.dtype, device=self._device
            ),
            lora_b=torch.rand(
                [weight.shape[0], rank], dtype=weight.dtype, device=self._device
            ),
        )
        self.set_module_lora(module_name, lora)

        return lora

    def init_lora(
        self,
        module_name: str,
        input_dim: int,
        output_dim: int,
        rank=8,
        noop=False,
        embeddings_tensor=None,
    ):
        """Create, register and return a random LoRA for an ``input_dim -> output_dim`` layer.

        ``lora_a`` has shape ``(rank, input_dim)`` and ``lora_b`` has shape
        ``(output_dim, rank)``, the same layout as :meth:`init_random_lora`.
        That layout makes ``lora_b @ lora_a`` an ``(output_dim, input_dim)`` update.

        Args:
            module_name: key under which the LoRA is registered.
            input_dim: input feature size of the adapted layer.
            output_dim: output feature size of the adapted layer.
            rank: LoRA rank.
            noop: accepted for interface compatibility; not used by this method.
            embeddings_tensor: forwarded unchanged to ``LoRALayerWeights``.
        """
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=torch.rand([rank, input_dim], device=self._device),
            # Previously [output_dim, input_dim], which is inconsistent with
            # lora_a's rank dimension.
            lora_b=torch.rand([output_dim, rank], device=self._device),
            embeddings_tensor=embeddings_tensor,
        )
        self.set_module_lora(module_name, lora)
        return lora

    def reset_lora(self):
        """Forget every registered LoRA."""
        self._loras = {}

    def init_packed_lora(
        self,
        module_name: str,
        input_dim: int,
        output_dims: list[int],
        noop_lora_index: list[int] | None = None,
        rank: int = 8,
    ):
        """Create one random LoRA per output dim, pack them, register and return the pack.

        Each sub-LoRA is registered as ``f"{module_name}_000_{i}"``; the packed
        result is registered under ``module_name``. Indices listed in
        ``noop_lora_index`` are passed to :meth:`init_lora` as ``noop=True``.
        """
        noop_lora_index_set = set(noop_lora_index or [])

        base_loras: list[LoRALayerWeights] = [
            self.init_lora(
                module_name + "_000_" + str(i),
                input_dim,
                out_dim,
                rank=rank,
                noop=i in noop_lora_index_set,
            )
            for i, out_dim in enumerate(output_dims)
        ]
        packed_lora = PackedLoRALayerWeights.pack(base_loras)
        self.set_module_lora(module_name, packed_lora)
        return packed_lora


def assert_close(a, b):
    """Assert ``a`` and ``b`` are close, using a tolerance chosen by ``a.dtype``.

    Half-precision types (float16/bfloat16) get 6e-2 and float32 gets 1e-2;
    the same value is used for both ``rtol`` and ``atol``. Any other dtype
    raises ``KeyError``.
    """
    tolerance_by_dtype = {
        torch.float16: 6e-2,
        torch.bfloat16: 6e-2,
        torch.float32: 1e-2,
    }
    tol = tolerance_by_dtype[a.dtype]
    torch.testing.assert_close(a, b, rtol=tol, atol=tol)


107
108
109
@dataclass
class PunicaTensors:
    inputs_tensor: torch.Tensor
110
    lora_weights: torch.Tensor | list[torch.Tensor]
111
112
113
114
115
116
117
    our_out_tensor: torch.Tensor
    ref_out_tensor: torch.Tensor
    b_seq_start_loc: torch.Tensor
    prompt_lora_mapping: torch.Tensor
    seq_len_tensor: torch.Tensor
    token_lora_mapping: torch.Tensor

118
    def meta(self) -> tuple[int, int]:
119
120
121
122
123
124
125
126
127
128
129
130
131
        """
        Infer max_seq_length and token_nums from the tensors
        and return them.
        """
        max_seq_length = self.seq_len_tensor.max()
        token_nums = self.seq_len_tensor.sum().item()
        if isinstance(max_seq_length, tuple):
            max_seq_length = max_seq_length[0].item()
        else:
            max_seq_length = max_seq_length.item()
        return max_seq_length, token_nums


def generate_data(
    batches,
    hidden_size,
    lora_nums,
    max_rank,
    seq_length,
    dtype,
    op_type,
    device,
) -> PunicaTensors:
    """Build random test data for a single-slice shrink or expand LoRA op.

    Every sequence has exactly ``seq_length`` tokens. Each sequence is assigned
    a LoRA index drawn uniformly from ``[0, lora_nums)``.

    Args:
        batches: number of sequences.
        hidden_size: hidden dimension of the layer.
        lora_nums: number of LoRA weight slots.
        max_rank: LoRA rank.
        seq_length: tokens per sequence.
        dtype: dtype of inputs, weights and the reference output.
        op_type: ``"shrink"`` builds shrink data; any other value builds
            expand data.
        device: device on which all returned tensors live.

    Returns:
        A ``PunicaTensors`` bundle.
    """
    # randint over [seq_length, seq_length + 1) always yields seq_length.
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    total_tokens = seq_len_tensor.sum()
    if op_type == "shrink":
        inputs_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device)
        lora_weights = torch.rand(
            (lora_nums, max_rank, hidden_size),  # col-major
            dtype=dtype,
        ).to(device)
        # The shrink op needs atomic_add, so the output is initialized to 0.
        ref_out_tensor = torch.zeros(
            (total_tokens, max_rank), dtype=dtype, device=inputs_tensor.device
        )
        # NOTE: the shrink kernel uses torch.float32 as its output type.
        our_out_tensor = torch.zeros((total_tokens, max_rank), dtype=torch.float32).to(
            device
        )
    else:
        inputs_tensor = torch.rand(
            (total_tokens, max_rank),
            dtype=dtype,
        ).to(device)
        lora_weights = torch.rand(
            (lora_nums, hidden_size, max_rank),  # col-major
            dtype=dtype,
        ).to(device)
        # The expand op computes y += a @ lora_b, so the output is
        # initialized randomly.
        ref_out_tensor = torch.rand(
            (total_tokens, hidden_size),
            dtype=dtype,
        ).to(device)
        # Both outputs start from identical contents.
        our_out_tensor = ref_out_tensor.clone()

    # torch.randint's upper bound is exclusive: use lora_nums so the last slot
    # can be sampled too (max(..., 1) keeps the original lora_nums <= 1 case).
    lora_indices_tensor = torch.randint(0, max(lora_nums, 1), (batches,)).to(device)

    # Repeat each sequence's LoRA index once per token in that sequence.
    indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)

    return PunicaTensors(
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    )


def generate_data_for_expand_nslices(
    batches,
    hidden_size,
    lora_nums,
    max_rank,
    seq_length,
    dtype,
    nslices,
    device,
) -> PunicaTensors:
    """Build random test data for an expand op whose output has ``nslices`` slices.

    Every sequence has exactly ``seq_length`` tokens. Each sequence is assigned
    a LoRA index drawn uniformly from ``[0, lora_nums)``. ``lora_weights`` is
    a list with one ``(lora_nums, hidden_size, max_rank)`` tensor per slice.
    The output buffers are ``(total_tokens, hidden_size * nslices)``.

    Returns:
        A ``PunicaTensors`` bundle whose tensors all live on ``device``.
    """
    # randint over [seq_length, seq_length + 1) always yields seq_length.
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    total_tokens = seq_len_tensor.sum()
    inputs_tensor = torch.rand(
        (total_tokens, max_rank),
        dtype=dtype,
    ).to(device)
    lora_weights_lst = [
        torch.rand(
            (lora_nums, hidden_size, max_rank),  # col-major
            dtype=dtype,
        ).to(device)
        for _ in range(nslices)
    ]
    # The expand op computes y += a @ lora_b, so the output is
    # initialized randomly.
    ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices), dtype=dtype).to(
        device
    )
    # Both outputs start from identical contents.
    our_out_tensor = ref_out_tensor.clone()

    # torch.randint's upper bound is exclusive: use lora_nums so the last slot
    # can be sampled too (max(..., 1) keeps the original lora_nums <= 1 case).
    lora_indices_tensor = torch.randint(0, max(lora_nums, 1), (batches,)).to(device)

    # Repeat each sequence's LoRA index once per token in that sequence.
    indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)

    return PunicaTensors(
        inputs_tensor,
        lora_weights_lst,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    )


def generate_data_for_nslices(
    batches,
    hidden_size,
    lora_nums,
    max_rank,
    seq_length,
    nslices,
    dtype,
    op_type,
    device,
) -> PunicaTensors:
    """Build random test data for an ``nslices``-slice shrink or expand op.

    Every sequence has exactly ``seq_length`` tokens. Each sequence is assigned
    a LoRA index drawn uniformly from ``[0, lora_nums)``. ``lora_weights`` is a
    list with one weight tensor per slice.

    Shrink (``op_type == "shrink"``):
        inputs ``(total_tokens, hidden_size)``; weights
        ``(lora_nums, max_rank, hidden_size)``; outputs are float32 zeros of
        shape ``(nslices, total_tokens, max_rank)``.
    Expand (any other ``op_type``):
        inputs ``(nslices, total_tokens, max_rank)``; weights
        ``(lora_nums, hidden_size, max_rank)``; outputs are random of shape
        ``(total_tokens, hidden_size * nslices)``.

    ``ref_out_tensor`` is always a clone of ``our_out_tensor``.
    """
    # randint over [seq_length, seq_length + 1) always yields seq_length.
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    total_tokens = seq_len_tensor.sum()

    if op_type == "shrink":
        inputs_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device)
        lora_weights_lst = [
            torch.rand(
                (lora_nums, max_rank, hidden_size),  # col-major
                dtype=dtype,
            ).to(device)
            for _ in range(nslices)
        ]
        # NOTE: the shrink kernel uses torch.float32 as its output type.
        # The shrink op needs atomic_add, so the output is initialized to 0.
        our_out_tensor = torch.zeros(
            (nslices, total_tokens, max_rank),
            dtype=torch.float32,
        ).to(device)
    else:
        inputs_tensor = torch.rand(
            (nslices, total_tokens, max_rank),
            dtype=dtype,
        ).to(device)
        lora_weights_lst = [
            torch.rand(
                (lora_nums, hidden_size, max_rank),  # col-major
                dtype=dtype,
            ).to(device)
            for _ in range(nslices)
        ]
        # The expand op computes y += a @ lora_b, so the output is
        # initialized randomly.
        our_out_tensor = torch.rand(
            (total_tokens, hidden_size * nslices), dtype=dtype
        ).to(device)

    # Both outputs start from identical contents.
    ref_out_tensor = our_out_tensor.clone()

    # torch.randint's upper bound is exclusive: use lora_nums so the last slot
    # can be sampled too (max(..., 1) keeps the original lora_nums <= 1 case).
    lora_indices_tensor = torch.randint(0, max(lora_nums, 1), (batches,)).to(device)

    # Repeat each sequence's LoRA index once per token in that sequence.
    indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)

    return PunicaTensors(
        inputs_tensor,
        lora_weights_lst,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    )


def create_peft_lora(
    model: torch.nn.Module,
    save_dir: str,
    target_modules: list[str],
    rank: int = 8,
    alpha: int = 16,
    dropout: float = 0.1,
    lora_dtype: torch.dtype = torch.float16,
) -> dict[str, torch.Tensor]:
    """Write a dummy PEFT-format LoRA adapter for ``model`` into ``save_dir``.

    For each dotted attribute path in ``target_modules``, the module's in/out
    sizes are read from ``input_size``/``output_size``, or from
    ``embedding_dim``/``num_embeddings`` (LM head). A Kaiming-uniform
    ``lora_A`` ``(rank, in_features)`` and an all-zero ``lora_B``
    ``(out_features, rank)`` are then created, so ``lora_B @ lora_A`` is zero.

    ``adapter_config.json`` and ``adapter_model.safetensors`` are written to
    ``save_dir``; the directory is created if it does not exist.

    Returns:
        The weight dict that was saved, keyed by PEFT-style names.

    Raises:
        AttributeError: if a path in ``target_modules`` does not resolve.
        ValueError: if a module's dimensions cannot be determined.
    """
    lora_weights = {}
    adapter_config = {
        "peft_type": "LORA",
        "auto_mapping": None,
        "base_model_name_or_path": "dummy_model",
        "revision": None,
        "task_type": "CAUSAL_LM",
        "inference_mode": False,
        "r": rank,
        "lora_alpha": alpha,
        "lora_dropout": dropout,
        "fan_in_fan_out": False,
        "bias": "none",
        "modules_to_save": None,
        "init_lora_weights": True,
        "layers_to_transform": None,
        "layers_pattern": None,
        "target_modules": target_modules,
        "exclude_modules": None,
        "use_rslora": False,
        "use_dora": False,
        "loftq_config": None,
    }

    for module_name in target_modules:
        module = model
        for attr in module_name.split("."):
            module = getattr(module, attr)

        if hasattr(module, "input_size") and hasattr(module, "output_size"):
            in_features = module.input_size
            out_features = module.output_size
        elif hasattr(module, "embedding_dim") and hasattr(module, "num_embeddings"):
            # ParallelLMHead
            in_features = module.embedding_dim
            out_features = module.num_embeddings
        else:
            raise ValueError(f"Unable to determine dimensions for module {module_name}")

        # The randn values are overwritten by kaiming_uniform_ below.
        lora_A = torch.randn(rank, in_features, dtype=lora_dtype)
        torch.nn.init.kaiming_uniform_(lora_A, a=5**0.5)

        lora_B = torch.zeros(out_features, rank, dtype=lora_dtype)

        # PEFT style
        lora_weights[f"base_model.model.{module_name}.lora_A.weight"] = lora_A
        lora_weights[f"base_model.model.{module_name}.lora_B.weight"] = lora_B

    # Create the target directory only after every module validated, so a bad
    # module list leaves nothing on disk.
    os.makedirs(save_dir, exist_ok=True)

    config_path = os.path.join(save_dir, "adapter_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(adapter_config, f, indent=2, ensure_ascii=False)

    weights_path = os.path.join(save_dir, "adapter_model.safetensors")
    save_file(lora_weights, weights_path)

    return lora_weights