# fully_sharded_layers.py
# pylint: disable=unused-argument
import functools
import inspect
from typing import TYPE_CHECKING, List, Optional, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.distributed.communication_op import (
    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora,
                              RowParallelLinearWithLoRA)

if TYPE_CHECKING:
    pass


def _fully_sharded_can_replace(can_replace):
    """
    Decorator that adds the fully-sharded-LoRA condition to a
    ``can_replace_layer()`` implementation.

    The wrapped callable first runs ``can_replace``. If that result is
    falsy it is returned as-is; otherwise the value of
    ``lora_config.fully_sharded_loras`` is returned.

    ``lora_config`` may be passed either by keyword or positionally; in the
    positional case it is found by binding the call against the signature
    of ``can_replace``.
    """
    signature = inspect.signature(can_replace)

    @functools.wraps(can_replace)
    def dec(*args, **kwargs):
        result = can_replace(*args, **kwargs)
        if not result:
            # Same short-circuit as `a and b`: lora_config is never
            # inspected when the base check already rejected the layer.
            return result
        if "lora_config" in kwargs:
            lora_config = kwargs["lora_config"]
        else:
            # Positional call: recover lora_config from the bound arguments
            # instead of failing with a KeyError on **kwargs.
            bound = signature.bind(*args, **kwargs)
            lora_config = bound.arguments["lora_config"]
        return lora_config.fully_sharded_loras

    return dec


# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.


class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Column-parallel LoRA layer that additionally shards LoRA A.

    Compared with ColumnParallelLinearWithLoRA, LoRA A is also sliced, along
    the rank dimension, following S-LoRA.
    """

    # When the `base_layer` is a `ColumnParallelLinear`, `lora_a` and
    # `lora_b` use different sharding patterns. The `lora_a` GEMM is
    # followed by an all-gather, so the slicing of `lora_a` only has to line
    # up with that gather.
    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return the block of ``lora_a`` columns owned by this TP rank.

        The block width is ``self.lora_a_stacked.shape[2]``.
        """
        width = self.lora_a_stacked.shape[2]
        first = get_tensor_model_parallel_rank() * width
        return lora_a[:, first:first + width]

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Apply the base layer plus this rank's sharded LoRA delta.

        ``x`` is shrunk by the local LoRA A shard into a float32 buffer. The
        buffer is all-gathered across tensor-parallel ranks, then expanded
        with LoRA B (and ``self.bias_stacked``) and added into the base
        output (``add_input=True``).
        """
        base_out = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        final_shape = base_out.shape

        # Collapse every leading dimension into a single token dimension.
        flat_x = x.view(-1, x.shape[-1])
        flat_out = base_out.view(-1, base_out.shape[-1])

        shrunk = torch.zeros(
            (flat_x.shape[0], self.lora_a_stacked.shape[2]),
            dtype=torch.float32,
            device=flat_x.device,
        )
        self.punica_wrapper.add_shrink(shrunk, flat_x, self.lora_a_stacked,
                                       1.0)
        gathered = tensor_model_parallel_all_gather(shrunk)
        self.punica_wrapper.add_expand(flat_out,
                                       gathered,
                                       self.lora_b_stacked,
                                       self.bias_stacked,
                                       add_input=True)
        # flat_out now holds the column-partitioned output.
        return flat_out.view(*final_shape)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Everything is forwarded by keyword so the decorators can read it
        # from **kwargs.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


def _mcp_apply(
    x: torch.Tensor,
    bias: Optional[torch.Tensor],
    layer: ("Union[MergedColumnParallelLinearWithLoRA, "
            "MergedQKVParallelLinearWithLora]"),
) -> torch.Tensor:
    """
    Shared ``apply`` for MergedColumnParallelLinearWithShardedLoRA and
    MergedQKVParallelLinearWithShardedLora.

    Runs the base layer, then shrinks ``x`` with every packed LoRA A slice
    into one float32 buffer. That buffer is all-gathered across
    tensor-parallel ranks and expanded with the packed LoRA B weights into
    the base output. The callers differ in the number of packed slices (2
    for merged column-parallel, 3 for merged QKV) and in
    ``layer.output_slices``, which gives ``add_expand_packed_nslice`` the
    per-slice output widths.

    Args:
        x: layer input; all leading dims are flattened into one for the
            LoRA kernels.
        bias: bias forwarded to the base layer's quant method.
        layer: the sharded merged LoRA layer whose weights are applied.

    Returns:
        The column-partitioned, packed output, reshaped back to the base
        layer's output shape.
    """
    # expecting 2 slices for merged column parallel and 3 for merged qkv
    n_slices = len(layer.lora_a_stacked)
    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape

    # A single contiguous tensor (not one buffer per slice) so one
    # all-gather covers every slice. Every slice is assumed to share the
    # rank-shard width of slice 0.
    buffers = torch.zeros(
        (n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
        dtype=torch.float32,
        device=x.device,
    )
    for idx, lora_a in enumerate(layer.lora_a_stacked):
        # Index (rather than iterate) `buffers` so each shrink writes into a
        # plain view of the shared tensor.
        layer.punica_wrapper.add_shrink(buffers[idx], x, lora_a, 1.0)

    buffers = tensor_model_parallel_all_gather(buffers)
    layer.punica_wrapper.add_expand_packed_nslice(
        output,
        buffers,
        layer.lora_b_stacked,
        layer.bias_stacked,
        1.0,
        layer.output_slices,
    )

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output


class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Merged column-parallel LoRA layer that additionally shards LoRA A.

    Compared with MergedColumnParallelLinearWithLoRA, each LoRA A is also
    sliced, along the rank dimension, following S-LoRA.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Slice this rank's rank-dim block out of both packed LoRA A's.

        ``lora_a`` holds two sub-LoRAs, either of which may be ``None``;
        ``None`` entries are passed through unchanged. Both slices use the
        block width of ``self.lora_a_stacked[0]``.
        """
        width = self.lora_a_stacked[0].shape[2]
        lo = self.tp_rank * width
        hi = lo + width
        return [
            lora_a[i][:, lo:hi] if lora_a[i] is not None else None
            for i in range(2)
        ]

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Delegate to the apply routine shared by the merged layers."""
        return _mcp_apply(x=x, bias=bias, layer=self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Everything is forwarded by keyword so the decorators can read it
        # from **kwargs.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
    """
    QKV-parallel LoRA layer that additionally shards LoRA A.

    Compared with QKVParallelLinearWithLora, LoRA A is also sliced, along the
    rank dimension, following S-LoRA.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return the rank-dim block of ``lora_a`` owned by this TP rank."""
        block = self.lora_a_stacked.shape[2]
        offset = block * get_tensor_model_parallel_rank()
        return lora_a[:, offset:offset + block]

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Apply the base layer plus this rank's sharded LoRA delta.

        The local LoRA A shrink is all-gathered across tensor-parallel ranks
        before LoRA B (and ``self.bias_stacked``) is expanded and added into
        the base output (``add_input=True``).
        """
        out = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        out_shape = out.shape

        # Collapse every leading dimension into a single token dimension.
        tokens = x.view(-1, x.shape[-1])
        out = out.view(-1, out.shape[-1])

        rank_shard = self.lora_a_stacked.shape[2]
        partial = torch.zeros((tokens.shape[0], rank_shard),
                              dtype=torch.float32,
                              device=tokens.device)
        self.punica_wrapper.add_shrink(partial, tokens, self.lora_a_stacked,
                                       1.0)
        full = tensor_model_parallel_all_gather(partial)
        self.punica_wrapper.add_expand(out,
                                       full,
                                       self.lora_b_stacked,
                                       self.bias_stacked,
                                       add_input=True)
        # `out` is now the column-partitioned output.
        return out.view(*out_shape)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Everything is forwarded by keyword so the decorators can read it
        # from **kwargs.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Merged QKV-parallel LoRA layer that additionally shards LoRA A.

    Compared with MergedQKVParallelLinearWithLora, each LoRA A is also
    sliced, along the rank dimension, following S-LoRA.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Slice this rank's rank-dim block out of the three packed LoRA A's.

        ``lora_a`` holds three sub-LoRAs, any of which may be ``None``;
        ``None`` entries are passed through unchanged. Slice ``i`` uses the
        block width ``self.lora_a_stacked[i].shape[2]``.
        """
        widths = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        offsets = [self.tp_rank * widths[i] for i in range(3)]
        return [
            lora_a[i][:, offsets[i]:offsets[i] + widths[i]]
            if lora_a[i] is not None else None for i in range(3)
        ]

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Delegate to the apply routine shared by the merged layers."""
        return _mcp_apply(x=x, bias=bias, layer=self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Everything is forwarded by keyword so the decorators can read it
        # from **kwargs.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Row-parallel LoRA layer that additionally shards LoRA B.

    Compared with RowParallelLinearWithLoRA, LoRA B is also sliced, along
    the output dimension, following S-LoRA. This yields a combined partial
    sum: the row-parallel base layer's output plus the column-partitioned
    LoRA output.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return this TP rank's block of ``lora_b`` along dim 1.

        The block width is ``self.lora_b_stacked.shape[2]``.
        """
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def slice_bias(
            self, bias: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Return this TP rank's block of ``bias``.

        ``None`` is accepted and returned unchanged, hence the ``Optional``
        annotations. The block width is ``self.bias_stacked.shape[2]``.
        """
        if bias is None:
            return bias
        shard_size = self.bias_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        bias = bias[start_idx:end_idx]
        return bias

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the row-parallel base layer plus the sharded LoRA delta.

        The LoRA A shrink result is combined across ranks with an
        all-reduce. The LoRA B expand then adds into only this rank's
        ``shard_size``-wide slice of the output. The returned tensor is
        still a partial sum and must be reduced before being used.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        # Collapse every leading dimension into a single token dimension.
        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros(
            (x.shape[0], self.lora_a_stacked.shape[2]),
            dtype=torch.float32,
            device=x.device,
        )

        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_reduce(buffer)

        # following S-LoRA, allows the fusing of all_gather and all_reduce
        # by adding the column partitioned lora output to a slice of output
        # tensor, which is a partial sum due to row parallel. All that
        # remains is a standard all_reduce. User should be aware though that
        # the output is not the same as a normal row_parallel, it should be
        # reduced before being used
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        self.punica_wrapper.add_expand_slice(output, buffer,
                                             self.lora_b_stacked,
                                             self.bias_stacked, start_idx,
                                             shard_size)
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )