# pylint: disable=unused-argument
import functools
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.distributed.communication_op import (
    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora,
                              RowParallelLinearWithLoRA)

if TYPE_CHECKING:
    pass


def _fully_sharded_can_replace(can_replace):
    """
    Decorator which adds the condition of fully sharded loras.

    Intended to wrap can_replace_layer(): the decorated check only passes
    when the wrapped check passes and `lora_config.fully_sharded_loras`
    is truthy.

    Args:
        can_replace: The `can_replace_layer` callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        `can_replace` unchanged and carries its name and docstring.

    Raises:
        KeyError: When the wrapped check passes but `lora_config` was not
            supplied as a keyword argument; the wrapper reads it from
            `kwargs` only.
    """

    @functools.wraps(can_replace)
    def dec(*args, **kwargs):
        # `and` short-circuits: the config flag is only consulted once the
        # wrapped check has accepted the layer.
        return (can_replace(*args, **kwargs)
                and kwargs["lora_config"].fully_sharded_loras)

    return dec


35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
    """
    Shared `apply` logic for `ColumnParallelLinearWithLoRA` and the classes
    that inherit from it.

    The base layer is applied first. The LoRA contribution is then computed
    in three steps: `punica_wrapper.add_shrink` fills a float32 buffer from
    `lora_a_stacked`, the buffer is passed through
    `tensor_model_parallel_all_gather`, and `punica_wrapper.add_expand`
    combines it with `lora_b_stacked` into the base output.

    Args:
        x: Input tensor; flattened to 2-D (all leading dims merged) before
            the LoRA steps.
        bias: Bias forwarded to the base layer's `quant_method.apply`.
        layer: The LoRA layer providing the base layer, stacked LoRA
            weights, output slices and punica wrapper.

    Returns:
        The output tensor, viewed back to the shape the base layer produced.
    """
    num_slices = layer.n_slices
    assert (num_slices == len(layer.lora_a_stacked) == len(
        layer.lora_b_stacked) == len(layer.output_slices))
    stacked_bias = layer.lora_bias_stacked
    if stacked_bias is not None:
        assert num_slices == len(stacked_bias)

    base = layer.base_layer
    base_out = base.quant_method.apply(base, x, bias)

    # Collapse all leading dimensions so both tensors are 2-D; remember the
    # base output's shape so it can be restored on return.
    flat_x = x.view(-1, x.shape[-1])
    restored_shape = base_out.shape
    flat_out = base_out.view(-1, base_out.shape[-1])

    # Since communication is needed, the buffer is directly initialized as a
    # tensor rather than a tuple of tensor.
    shard_rank = layer.lora_a_stacked[0].shape[2]
    shrunk = torch.zeros(
        (num_slices, flat_x.shape[0], shard_rank),
        dtype=torch.float32,
        device=flat_x.device,
    )

    wrapper = layer.punica_wrapper
    wrapper.add_shrink(shrunk, flat_x, layer.lora_a_stacked, 1.0)
    gathered = tensor_model_parallel_all_gather(shrunk)
    wrapper.add_expand(flat_out,
                       gathered,
                       layer.lora_b_stacked,
                       layer.lora_bias_stacked,
                       layer.output_slices,
                       offset_start=0,
                       add_input=True)

    # now have column partitioned and packed output
    return flat_out.view(*restored_shape)


# These layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., "S-LoRA: Serving Thousands of Concurrent LoRA Adapters",
# 2023, https://arxiv.org/abs/2311.03285.


class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
    # their `lora_a` and `lora_b` have different sharding patterns. After
    # completing the `lora_a` GEMM, a gather operation is performed.
    # Therefore, the sharding of `lora_a` only needs to correspond with the
    # gather operation.
    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """
        Return this tensor-parallel rank's shard of `lora_a`.

        The shard width is `self.lora_a_stacked[0].shape[2]`; the shard
        covers columns `[tp_rank * width, (tp_rank + 1) * width)` of
        `lora_a` (second dimension).
        """
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = tp_rank * shard_size
        return lora_a[:, start_idx:start_idx + shard_size]

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the base layer plus LoRA via the shared `_mcp_apply`."""
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """
        Return whether this class can replace `source_layer`.

        Delegates to the parent class's check; the
        `_fully_sharded_can_replace` decorator additionally requires
        `lora_config.fully_sharded_loras`. Callers must pass `lora_config`
        by keyword because the decorator reads it from the keyword arguments.
        """
        # specifying kwargs so they can be easily accessed in decorator
        # NOTE(review): `decorate=False` presumably disables the parent's
        # own decorator condition -- confirm in vllm.lora.layers.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """
        Return this tensor-parallel rank's shard of every sub-LoRA A.

        Every entry is sliced with the same column range
        `[tp_rank * width, (tp_rank + 1) * width)`, where `width` is
        `self.lora_a_stacked[0].shape[2]`. Entries that are `None` stay
        `None`. The result has one entry per entry of `lora_a`.
        """
        # NOTE: each sublora could be None.
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        output_end_idx = output_start_idx + output_shard_size
        return [
            sub_lora[:, output_start_idx:output_end_idx]
            if sub_lora is not None else None for sub_lora in lora_a
        ]

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the base layer plus LoRA via the shared `_mcp_apply`."""
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """
        Return whether this class can replace `source_layer`.

        Delegates to the parent class's check; the
        `_fully_sharded_can_replace` decorator additionally requires
        `lora_config.fully_sharded_loras`. Callers must pass `lora_config`
        by keyword because the decorator reads it from the keyword arguments.
        """
        # specifying kwargs so they can be easily accessed in decorator
        # NOTE(review): `decorate=False` presumably disables the parent's
        # own decorator condition -- confirm in vllm.lora.layers.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


168
class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
    """
    Differs from QKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """
        Return this tensor-parallel rank's shard of `lora_a`.

        The shard width is `self.lora_a_stacked[0].shape[2]`; the shard
        covers columns `[tp_rank * width, (tp_rank + 1) * width)` of
        `lora_a` (second dimension).
        """
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = tp_rank * shard_size
        return lora_a[:, start_idx:start_idx + shard_size]

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the base layer plus LoRA via the shared `_mcp_apply`."""
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """
        Return whether this class can replace `source_layer`.

        Delegates to the parent class's check; the
        `_fully_sharded_can_replace` decorator additionally requires
        `lora_config.fully_sharded_loras`. Callers must pass `lora_config`
        by keyword because the decorator reads it from the keyword arguments.
        """
        # specifying kwargs so they can be easily accessed in decorator
        # NOTE(review): `decorate=False` presumably disables the parent's
        # own decorator condition -- confirm in vllm.lora.layers.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Differs from MergedQKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """
        Return this tensor-parallel rank's shard of every sub-LoRA A.

        Entry `i` is sliced to columns
        `[tp_rank * width_i, (tp_rank + 1) * width_i)`, where `width_i` is
        `self.lora_a_stacked[i].shape[2]`. Entries that are `None` stay
        `None`. The result has one entry per entry of `lora_a`.
        """
        # NOTE: each sublora could be None.
        sliced: List[Union[torch.Tensor, None]] = []
        for idx, sub_lora in enumerate(lora_a):
            if sub_lora is None:
                sliced.append(None)
                continue
            # Each sub-LoRA has its own shard width, so the offset is
            # computed per entry.
            shard_size = self.lora_a_stacked[idx].shape[2]
            start_idx = self.tp_rank * shard_size
            sliced.append(sub_lora[:, start_idx:start_idx + shard_size])
        return sliced

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the base layer plus LoRA via the shared `_mcp_apply`."""
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """
        Return whether this class can replace `source_layer`.

        Delegates to the parent class's check; the
        `_fully_sharded_can_replace` decorator additionally requires
        `lora_config.fully_sharded_loras`. Callers must pass `lora_config`
        by keyword because the decorator reads it from the keyword arguments.
        """
        # specifying kwargs so they can be easily accessed in decorator
        # NOTE(review): `decorate=False` presumably disables the parent's
        # own decorator condition -- confirm in vllm.lora.layers.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """
        Return this tensor-parallel rank's shard of `lora_b`.

        The shard width is `self.lora_b_stacked[0].shape[2]`; the shard
        covers columns `[tp_rank * width, (tp_rank + 1) * width)` of
        `lora_b` (second dimension).
        """
        shard_size = self.lora_b_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        return lora_b[:, start_idx:end_idx]

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """
        Return this tensor-parallel rank's shard of `bias`.

        `None` is passed through unchanged. Otherwise the shard width is
        `self.lora_bias_stacked[0].shape[2]` and the shard covers elements
        `[tp_rank * width, (tp_rank + 1) * width)` of `bias`.
        """
        if bias is None:
            return bias
        # `cast` only narrows the type for the checker; keep the result in a
        # local instead of re-assigning the attribute.
        lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                 self.lora_bias_stacked)
        shard_size = lora_bias_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        return bias[start_idx:end_idx]

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Apply the base layer and add the sharded LoRA contribution.

        Args:
            x: Input tensor; flattened to 2-D before the LoRA steps.
            bias: Accepted for interface compatibility but not used here;
                the base layer is applied without it.

        Returns:
            The output tensor in the shape produced by the base layer. As
            noted below, it still has to be reduced before being used.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        x = x.view(-1, x.shape[-1])
        # Capture the shape before flattening so it can be restored below.
        out_orig_shape = output.shape
        output = output.view(-1, output.shape[-1])
        buffer = torch.zeros(
            (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
            dtype=torch.float32,
            device=x.device,
        )

        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_reduce(buffer)

        # following S-LoRA, allows the fusing of all_gather and all_reduce
        # by adding the column partitioned lora output to a slice of output
        # tensor, which is a partial sum due to row parallel. All that
        # remains is a standard all_reduce. User should be aware though that
        # the output is not the same as a normal row_parallel, it should be
        # reduced before being used
        # NOTE offset are based on the rank.
        shard_size = self.lora_b_stacked[0].shape[2]
        offset_start = self.tp_rank * shard_size
        self.punica_wrapper.add_expand(
            output,
            buffer,
            self.lora_b_stacked,
            self.lora_bias_stacked,
            self.output_slices,
            offset_start=offset_start,
            add_input=True,
        )
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """
        Return whether this class can replace `source_layer`.

        Delegates to the parent class's check; the
        `_fully_sharded_can_replace` decorator additionally requires
        `lora_config.fully_sharded_loras`. Callers must pass `lora_config`
        by keyword because the decorator reads it from the keyword arguments.
        """
        # specifying kwargs so they can be easily accessed in decorator
        # NOTE(review): `decorate=False` presumably disables the parent's
        # own decorator condition -- confirm in vllm.lora.layers.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )