"vscode:/vscode.git/clone" did not exist on "b3710d2c93b6f1ef608990096d71817c5cf35608"
lora.py 15.5 KB
Newer Older
drbh's avatar
drbh committed
1
2
3
4
5
6
# Origin:   https://github.com/predibase/lorax
# Path:     lorax/server/lorax_server/adapters/lora.py
# License:  Apache License Version 2.0, January 2004

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Type, Union

import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup

from text_generation_server.adapters.config import AdapterConfig, ModuleMap

from text_generation_server.adapters.weights import (
    AdapterBatchMetadata,
    AdapterWeights,
    BatchAdapterWeights,
)
from text_generation_server.utils.sgmv import (
    BGMV_MAX_RANK,
    MAX_RANK_CUSTOM,
    get_tmp_tensors,
    orient_for_rank,
    pad_rank,
    use_cutlass_shrink,
)


def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
    block_size = size // world_size
    start = offset + rank * block_size
    stop = offset + (rank + 1) * block_size
    return start, stop


def shard_on_dim(
    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
):
    world_size = process_group.size()
    rank = process_group.rank()

    size = t.shape[dim]
    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)

    if dim == 0:
        tensor = t[start:stop]
    elif dim == 1:
        tensor = t[:, start:stop]
    else:
        raise NotImplementedError("Let's make that generic when needed")

    return tensor


def shard_lora_weights(
    weights_a: List[torch.Tensor],
    weights_b: List[torch.Tensor],
    split_dim: int,
    process_group: ProcessGroup,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # [hidden_size, r]
    weights_a = [
        shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
    ]

    # [r, hidden_size]
    weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]

    return weights_a, weights_b
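
# Illustrative sketch (hypothetical shapes, 2-way tensor parallelism): with split_dim=0,
# each lora_A of shape [4096, 16] is sharded into [2048, 16] per rank, while the matching
# lora_B of shape [16, 4096] is always sharded along dim 1 into [16, 2048], e.g.:
#
#   sharded_a, sharded_b = shard_lora_weights(
#       weights_a=[torch.zeros(4096, 16)],
#       weights_b=[torch.zeros(16, 4096)],
#       split_dim=0,
#       process_group=process_group,  # assumed: a 2-rank ProcessGroup
#   )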


@dataclass
class LoraConfig(AdapterConfig):
    r: int
    target_modules: Optional[Union[List[str], str]]
    fan_in_fan_out: bool
    lora_alpha: int
    use_rslora: bool

    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        adapter_weight_names = set()
        module_map = {}
        for weight_name in weight_names:
            lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
            lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
            if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
                continue

            module_map[weight_name] = {
                "lora_A": (adapter_weights[lora_a_name], lora_a_name),
                "lora_B": (adapter_weights[lora_b_name], lora_b_name),
            }
            adapter_weight_names.add(lora_a_name)
            adapter_weight_names.add(lora_b_name)
        return module_map, adapter_weight_names

    @classmethod
    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
        return cls(
            base_model_name_or_path=hf_config.base_model_name_or_path,
            r=hf_config.r,
            target_modules=hf_config.target_modules,
            fan_in_fan_out=hf_config.fan_in_fan_out,
            lora_alpha=hf_config.lora_alpha,
            use_rslora=(
                hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
            ),
        )


class LoraWeights(AdapterWeights):
    """LoRA weights for a single adapter merged across all layers."""

    def __init__(
        self,
        weights_a: List[torch.Tensor],
        weights_b: List[torch.Tensor],
        adapter_config: LoraConfig,
    ):
        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
        self.lora_b_r = weights_b[0].size(0) if len(weights_b) > 0 else 1

        self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
        self._is_transposed = False

        # [num_layers, hidden_size, r]
        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
        self._weights_a = torch.stack(weights_a)

        # [num_layers, r, hidden_size]
        self._weights_b = torch.stack(weights_b)

        self.adapter_config = adapter_config

    @property
    def weights_a(self) -> torch.Tensor:
        if self._is_transposed:
            self._transpose_weights()
        return self._weights_a

    @property
    def weights_b(self) -> torch.Tensor:
        if self._is_transposed:
            self._transpose_weights()
        return self._weights_b

    @property
    def weights_a_t(self) -> torch.Tensor:
        if not self._is_transposed:
            self._transpose_weights()
        return self._weights_a

    @property
    def weights_b_t(self) -> torch.Tensor:
        if not self._is_transposed:
            self._transpose_weights()
        return self._weights_b

    def _transpose_weights(self):
        if self._use_cutlass_shrink:
            # Only the cutlass shrink path needs the A weights transposed; without it,
            # SGMV and BGMV already use the same orientation.
            self._weights_a = self._weights_a.transpose(1, 2).contiguous()
        self._weights_b = self._weights_b.transpose(1, 2).contiguous()
        self._is_transposed = not self._is_transposed

    @classmethod
    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
        return [BatchLoraWeights]

    # prepare pre-loaded lora weights for use in the model.
    #
    # this method processes and organizes lora weights for a specific layer type across all layers:
    # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
    # - retrieves weights from `module_map` based on the `layer_type`.
    # - processes `nlayers` number of layers.
    # - converts weights to the specified `dtype`.
    # - shards weights across `world_size` number of processes using the `process_group`.
    # - maps weights to specific layers using `target_to_layer`.
    # - tracks `unused_weight_names` to identify any unused weights.
    #
    # the method handles weight transposition, scaling, and padding to ensure compatibility
    # with SGMV or BGMV operations.
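    #
    # illustrative per-layer sketch (hypothetical shapes, r=8, hidden_size=4096): a
    # peft-style lora_A weight of shape [8, 4096] is transposed to [4096, 8], and the
    # matching lora_B weight of shape [4096, 8] is transposed to [8, 4096] and multiplied
    # by the scaling factor; ranks are then padded, the per-layer lists are sharded across
    # the process group, and the result is stacked into a single LoraWeights instance.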
    @classmethod
    def prepare_weights(
        cls,
        config: LoraConfig,
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
        nlayers: int,
        dtype: torch.dtype,
        world_size: int,
        process_group: ProcessGroup,
        target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
    ) -> Optional[AdapterWeights]:
        lora_a_list = [None] * nlayers
        lora_b_list = [None] * nlayers

        for layer_id in range(nlayers):
            key = (layer_id, layer_type)
            weight_name, layer = target_to_layer[key]
            base_weight = layer.base_layer.linear.weight
            base_device = base_weight.device

            if weight_name not in module_map:
                # There is no LoRA weight for this layer type in the adapter
                return None

            lora_a, lora_a_name = module_map[weight_name]["lora_A"]
            lora_a = lora_a.to(base_device, dtype)

            lora_b, lora_b_name = module_map[weight_name]["lora_B"]
            lora_b = lora_b.to(base_device, dtype)

            scale = get_scaling_factor(
                config.lora_alpha,
                config.r,
                uses_rslora=config.use_rslora,
            )

            unused_weight_names.discard(lora_a_name)
            unused_weight_names.discard(lora_b_name)

            # Merge scaling factor into lora_b due to associativity of matrix multiplication:
            # (A * B) * C = A * (B * C)
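            # Concretely: (x @ lora_a) @ (lora_b * scale) == ((x @ lora_a) @ lora_b) * scale,
            # so the scaling can be folded into lora_b once at load time.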
            lora_a_list[layer_id] = lora_a.transpose(0, 1)
            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

        # pad lora ranks to be compatible with sgmv
        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]

        if lora_a_list:
            # update rank if it was padded
            padded_rank = lora_a_list[0].size(1)
            config.r = padded_rank

        return LoraWeights(
            *shard_lora_weights(
                weights_a=lora_a_list,
                weights_b=lora_b_list,
                split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
                process_group=process_group,
            ),
            config,
        )


@dataclass
class RankSegments:
    rank: int

    lora_a_ptr: torch.Tensor
    lora_b_ptr: torch.Tensor

    # prefill (sgmv)
    tmp_shrink: torch.Tensor
    tmp_expand: torch.Tensor
    segment_starts: torch.Tensor
    segment_ends: torch.Tensor

    # decode (bgmv)
    indices: torch.Tensor
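    # Only one of the two groups above is populated per batch: the SGMV path (prefill, or
    # ranks above BGMV_MAX_RANK) fills the tmp_* buffers and segment bounds, while the
    # BGMV path fills `indices` with one entry per token (-1 for tokens handled by a
    # different rank bucket).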


@dataclass
class BatchLoraWeights(BatchAdapterWeights):
    lora_a: Dict[int, torch.Tensor]
    lora_b: Dict[int, torch.Tensor]
    adapter_index_configs: Dict[int, LoraConfig]
    rank_data: Dict[int, RankSegments]
    use_sgmv: bool

    def has_adapter(self, adapter_index: int) -> bool:
        return adapter_index in self.adapter_index_configs

    def can_vectorize(self, pg: ProcessGroup) -> bool:
        return all(
            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
            for rank_data in self.rank_data.values()
        )

    @classmethod
    def load(
        cls,
        adapter_weights: Dict[int, AdapterWeights],
        meta: AdapterBatchMetadata,
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Optional["BatchLoraWeights"]:
        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
        adapter_weights = {
            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
        }
        if not adapter_weights:
            return None

        first_weights = next(iter(adapter_weights.values()))
        device = first_weights.weights_a.device
        segment_indices = meta.segment_indices

        lora_a = {
            idx: adapter_weights[idx].weights_a
            for idx in segment_indices
            if idx in adapter_weights
        }
        lora_b = {
            idx: adapter_weights[idx].weights_b
            for idx in segment_indices
            if idx in adapter_weights
        }

        max_rank = max(
            (
                adapter_weights[idx].lora_a_r
                for idx in segment_indices
                if idx in adapter_weights
            ),
            default=0,
        )

        if prefill or max_rank > BGMV_MAX_RANK:
            use_sgmv = True
            lora_a_ptr = torch.tensor(
                [
                    (
                        adapter_weights[idx].weights_a.data_ptr()
                        if idx in adapter_weights
                        else 0
                    )
                    for idx in segment_indices
                ],
                dtype=torch.int64,
                device=device,
            )
            lora_b_ptr = torch.tensor(
                [
                    (
                        adapter_weights[idx].weights_b.data_ptr()
                        if idx in adapter_weights
                        else 0
                    )
                    for idx in segment_indices
                ],
                dtype=torch.int64,
                device=device,
            )
        else:
            use_sgmv = False
            lora_a_ptr = torch.tensor(
                [
                    (
                        adapter_weights[idx].weights_a_t.data_ptr()
                        if idx in adapter_weights
                        else 0
                    )
                    for idx in segment_indices
                ],
                dtype=torch.int64,
                device=device,
            )
            lora_b_ptr = torch.tensor(
                [
                    (
                        adapter_weights[idx].weights_b_t.data_ptr()
                        if idx in adapter_weights
                        else 0
                    )
                    for idx in segment_indices
                ],
                dtype=torch.int64,
                device=device,
            )

        adapter_index_configs = {
            idx: adapter_weights[idx].adapter_config
            for idx in segment_indices
            if idx in adapter_weights
        }

        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}

        rank_indices = defaultdict(list)
        for segment_idx, adapter_idx in enumerate(segment_indices):
            if adapter_idx not in adapter_weights:
                continue
            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)

        if prefill_head_indices is not None:
            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
            for head_index in prefill_head_indices:
                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
                if head_index < meta.adapter_segments[j]:
                    prefill_head_segment_ends[-1] += 1
                else:
                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
                    j += 1

        rank_data = {}
        for rank, indices in rank_indices.items():
            tmp_shrink = None
            tmp_expand = None
            segment_starts = None
            segment_ends = None
            batch_indices = None

            if use_sgmv:
                lora_a_ptr_indices = lora_a_ptr[indices]
                tmp_shrink, tmp_expand = get_tmp_tensors(
                    lora_a_ptr_indices.size(0), rank, device
                )
                segment_starts = meta.adapter_segments[indices]
                segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
                if prefill_head_indices is not None:
                    for i, segment_index in enumerate(indices):
                        segment_starts[i] = prefill_head_segment_starts[segment_index]
                        segment_ends[i] = prefill_head_segment_ends[segment_index]
            else:
                # Use a distinct name so the `rank_indices` dict being iterated is not shadowed.
                indices_for_rank = set(indices)
                batch_indices = [
                    adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
                ]
                batch_indices = [
                    idx if idx in indices_for_rank else -1 for idx in batch_indices
                ]
                batch_indices = torch.tensor(
                    batch_indices, dtype=torch.int64, device=device
                )

            rank_data[rank] = RankSegments(
                rank=rank,
                tmp_shrink=tmp_shrink,
                tmp_expand=tmp_expand,
                lora_a_ptr=lora_a_ptr[indices],
                lora_b_ptr=lora_b_ptr[indices],
                segment_starts=segment_starts,
                segment_ends=segment_ends,
                indices=batch_indices,
            )

        return BatchLoraWeights(
            lora_a=lora_a,
            lora_b=lora_b,
            adapter_index_configs=adapter_index_configs,
            rank_data=rank_data,
            use_sgmv=use_sgmv,
        )


def get_scaling_factor(
    lora_alpha: int,
    r: int,
    uses_rslora: bool = False,
) -> float:
    """Computes the scaling factor for the lora weights."""
    if uses_rslora:
        return lora_alpha / (r**0.5)
    return lora_alpha / r
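
# For example (hypothetical values): get_scaling_factor(lora_alpha=16, r=8) returns 2.0,
# while get_scaling_factor(lora_alpha=16, r=8, uses_rslora=True) returns 16 / 8**0.5,
# about 5.66.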


def _convert_lora(v: AdapterWeights) -> AdapterWeights:
    if hasattr(v, "lora_weights"):
        return v.lora_weights
    return v