from typing import TYPE_CHECKING, Optional, List

import torch
import torch.distributed
from torch import nn
from torch.distributed import ProcessGroup

from text_generation_server.utils.sgmv import (
    add_lora_a_bgmv,
    add_lora_b_bgmv,
    has_sgmv,
    lora_a_sgmv_cutlass,
    lora_b_sgmv_cutlass,
    orient_for_rank,
)

if TYPE_CHECKING:
    from text_generation_server.adapters import AdapterBatchData
    from text_generation_server.adapters.lora import BatchLoraWeights


class LoraLinear(nn.Module):
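    """Base class for applying LoRA adapters on top of a linear layer.

    Subclasses implement `collect_lora_a` to reconcile the intermediate
    LoRA activation (`input @ lora_a`) across tensor-parallel ranks.
    """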
    def __init__(
        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
    ):
        super().__init__()
        self.base_layer = base_layer
        self.layer_id = layer_id
        self.process_group = process_group

    def forward_layer_type(
        self,
        result: torch.Tensor,
        input: torch.Tensor,
        adapter_data: "AdapterBatchData",
        layer_type: str,
        start_idx: int,
        end_idx: int,
    ) -> torch.Tensor:
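        """Apply the active LoRA adapters for `layer_type` to the slice
        `result[:, start_idx:end_idx]` in place and return `result`.

        Uses the SGMV/BGMV kernels when available and the batch can be
        vectorized; otherwise falls back to a per-adapter masked matmul
        via `forward_lora`.
        """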
        if adapter_data is None:
            return result
        data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)

        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
            # In tensor-parallel configurations, each GPU processes a specific segment of the output.
            # The 'result' tensor represents the full output, which can vary in size based on
            # the layer type (e.g., attention vs. feed-forward layers). We define the current
            # segment using start_idx and end_idx. If the segment size doesn't match this GPU's
            # slice of 'result', we create a zero tensor of the correct size for LoRA computation.
            # This approach ensures accurate LoRA application across various layer sizes and
            # configurations, adapting to different model architectures and parallelization strategies.
            #
            # Example scenarios where this is necessary:
            # 1. The adapter's size doesn't evenly divide across GPUs.
            # 2. We're processing the last segment which might be smaller.
            # 3. Different projection layers (q, k, v) have different sizes.
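            #
            # Illustrative (hypothetical) numbers: with 2 ranks and fused
            # q/k/v sizes of 4096/1024/1024, this rank's `result` has 3072
            # columns, while the q_proj segment alone spans columns
            # [0, 2048), so a separate zero buffer is needed for it.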
            if end_idx - start_idx != result.shape[1]:
                proj = torch.zeros_like(result[:, start_idx:end_idx])
            else:
                proj = result

            for r, rank_segments in data.rank_data.items():
                lora_a_ptr = rank_segments.lora_a_ptr
                lora_b_ptr = rank_segments.lora_b_ptr

                if lora_a_ptr is None or lora_b_ptr is None:
                    raise ValueError("LoRA data is missing")

                if data.use_sgmv:
                    # Use SGMV for prefill
                    v = lora_a_sgmv_cutlass(
                        input,
                        rank_segments.tmp_shrink,
                        lora_a_ptr,
                        rank_segments.segment_starts,
                        rank_segments.segment_ends,
                        self.layer_id,
                        r,
                    )

                    if self.process_group.size() > 1:
                        v = self.collect_lora_a(v)

                    lora_b_sgmv_cutlass(
                        proj,
                        v,
                        rank_segments.tmp_expand,
                        lora_b_ptr,
                        rank_segments.segment_starts,
                        rank_segments.segment_ends,
                        self.layer_id,
                    )
                else:
                    # Use BGMV for decode
                    v = torch.zeros(
                        (input.size(0), r), dtype=input.dtype, device=input.device
                    )
                    # TODO: error with [-1, 0], but not [0, -1]
                    add_lora_a_bgmv(
                        v,
                        input,
                        lora_a_ptr,
                        rank_segments.indices,
                        self.layer_id,
                    )

                    if self.process_group.size() > 1:
                        v = self.collect_lora_a(v)

                    add_lora_b_bgmv(
                        proj,
                        v,
                        lora_b_ptr,
                        rank_segments.indices,
                        self.layer_id,
                    )

            if end_idx - start_idx != result.shape[1]:
                result[:, start_idx:end_idx] += proj
        else:
            for adapter_index in adapter_data.meta.adapter_set:
                if data is not None and data.has_adapter(adapter_index):
                    adapter_mask = (
                        (adapter_data.meta.adapter_indices == adapter_index)
                        .to(input.dtype)
                        .view(-1, 1)
                    )
                    layer_result = self.forward_lora(
                        input, data, adapter_index, adapter_mask
                    )
                    result[:, start_idx:end_idx] += layer_result

        return result

    def forward_lora(
        self,
        input: torch.Tensor,
        data: "BatchLoraWeights",
        adapter_index: int,
        adapter_mask: torch.Tensor,
    ) -> torch.Tensor:
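        """Compute the dense LoRA update for a single adapter.

        Projects `input` down through `lora_a`, reconciles the intermediate
        activation across ranks when tensor parallelism is used, projects
        back up through `lora_b`, and zeroes out rows that do not belong to
        `adapter_index` via `adapter_mask`.
        """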
        lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]

        lora_a = orient_for_rank(lora_a, lora_b.size(0))

        a_out = input @ lora_a
        if self.process_group.size() > 1:
            a_out = self.collect_lora_a(a_out)

        result = (a_out @ lora_b) * adapter_mask
        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("Implemented in subclasses")


class TensorParallelMultiAdapterLinear(LoraLinear):
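    """LoRA wrapper for column-parallel layers that fuse several projections
    (e.g. q_proj/k_proj/v_proj) into a single matmul.

    `layer_names` and `sizes` describe the fused sub-layers so that each
    one's slice of the output receives its own adapter weights.
    """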
    def __init__(
        self,
        base_layer: nn.Module,
        layer_id: int,
        layer_names: List[str],
        sizes: List[int],
        process_group: ProcessGroup,
    ):
        super().__init__(base_layer, layer_id, process_group)
        self.layer_names = layer_names
        self.sizes = sizes

    @classmethod
    def load(
        cls,
        base_layer: nn.Module,
        layer_id: int,
        layer_names: List[str],
        sizes: List[int],
        process_group: ProcessGroup,
    ):
        return TensorParallelMultiAdapterLinear(
            base_layer, layer_id, layer_names, sizes, process_group
        )

    def forward(
        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
    ) -> torch.Tensor:
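        """Run the base layer, then apply LoRA for each fused sub-layer in
        `layer_names` to its corresponding slice of the output.
        """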
        result = self.base_layer(input)

        # noop if no layer names are provided (e.g. for models without adapters)
        if self.layer_names is None:
            return result

        # handle models like Bloom that have inputs of shape
        # (batch_size, sequence_length, hidden_size)
        # we need to reshape them to (batch_size * sequence_length, hidden_size)
        # for the LoRA computation, then reshape back
        prev_shape = result.shape
        is_3d = len(input.shape) >= 3
        if is_3d:
            input = input.reshape(-1, input.shape[-1])
            result = result.reshape(-1, result.shape[-1])

        offset = 0
        for i, layer_name in enumerate(self.layer_names):
            start_idx = offset // self.process_group.size()
            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
            # different projection sizes across layers and model architectures.
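            # Illustrative (hypothetical) numbers: with sizes [4096, 1024, 1024]
            # for q/k/v and 2 ranks, the per-rank slices become [0, 2048),
            # [2048, 2560) and [2560, 3072).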
            if self.sizes is not None:
                offset += self.sizes[i]
                end_idx = offset // self.process_group.size()
            else:
                end_idx = result.shape[1]

            result = self.forward_layer_type(
                result, input, adapter_data, layer_name, start_idx, end_idx
            )

        if is_3d:
            result = result.reshape(prev_shape)

        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
        # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
        #
        # TODO(travis): this is not very efficient as we do an all-gather for every adapter,
        #   instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
        #   rank, compute `a_out` on each, and then slice them into the buffer as shown here:
        #   https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
        gathered_tensors = [
            torch.empty_like(a_out) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(gathered_tensors, a_out)
        return torch.cat(gathered_tensors, dim=1)


class TensorParallelAdapterRowLinear(LoraLinear):
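    """LoRA wrapper for row-parallel layers with a single `layer_name`;
    each rank applies the adapters to its own slice of the output using
    the S-LoRA fused communication scheme referenced in `forward`.
    """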
    def __init__(self, base_layer, layer_id, layer_name, process_group):
        super().__init__(base_layer, layer_id, process_group)
        self.layer_name = layer_name

    @classmethod
    def load(cls, base_layer, layer_id, layer_name, process_group):
        return cls(base_layer, layer_id, layer_name, process_group)

    def forward(
        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
    ) -> torch.Tensor:
        result = self.base_layer(input)

        if self.layer_name is None:
            return result

        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
        stride = result.shape[-1] // self.process_group.size()
        start_idx = self.process_group.rank() * stride
        end_idx = (self.process_group.rank() + 1) * stride
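        # Illustrative (hypothetical) numbers: with a 4096-wide output and
        # 2 ranks, rank 0 updates columns [0, 2048) and rank 1 [2048, 4096).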

        self.forward_layer_type(
            result, input, adapter_data, self.layer_name, start_idx, end_idx
        )

        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
        # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
        #
        # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
        #   instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
        #   rank, compute `a_out` on each, and then slice them into the buffer as shown here:
        #   https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
        torch.distributed.all_reduce(a_out, group=self.process_group)
        return a_out