# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union, cast

import torch
import torch.nn as nn
from transformers import PretrainedConfig

10
from vllm.config.lora import LoRAConfig
11
12
13
14
from vllm.distributed import (
    split_tensor_along_last_dim,
    tensor_model_parallel_all_reduce,
)
Jee Jee Li's avatar
Jee Jee Li committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform

from .base_linear import BaseLinearLayerWithLoRA
from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace


class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA-enabled wrapper around a ``RowParallelLinear`` base layer.

    LoRA A is sliced per tensor-parallel rank along its second dimension;
    LoRA B and the LoRA bias are passed through unsliced.
    """

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__(base_layer)

        wrapped = self.base_layer
        # Override input_size with the per-partition size of the base layer.
        self.input_size = wrapped.input_size_per_partition
        self.output_size = wrapped.output_size
        # A row-parallel layer carries exactly one LoRA slice.
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return the columns of ``lora_a`` that belong to this rank."""
        width = self.input_size
        first = self.tp_rank * width
        return lora_a[:, first : first + width]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return ``lora_b`` unchanged (no slicing in this variant)."""
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Return ``bias`` unchanged (no slicing in this variant)."""
        return bias

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Run the row-parallel linear layer with LoRA applied.

        Args:
            input_: tensor whose last dimension is `input_size`. If the
                    base layer has `input_is_parallel` set, the input is
                    used as-is; otherwise it is split along the last
                    dimension into `tp_size` partitions and this rank's
                    partition is used.

        Returns:
            The output tensor alone when the base layer's `return_bias`
            is falsy, otherwise a tuple of:
            - output
            - bias (only non-None when `skip_bias_add` is set)
        """
        layer = self.base_layer

        # Select the slice of the input handled by this rank.
        if layer.input_is_parallel:
            local_input = input_
        else:
            # TODO: simplify code below
            partitions = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            local_input = partitions[self.tp_rank].contiguous()

        # Matrix multiply (base layer plus LoRA).
        partial = self.apply(local_input)

        # Sum the partial results across ranks only when requested and
        # when there is more than one rank.
        if layer.reduce_results and self.tp_size > 1:
            reduced = tensor_model_parallel_all_reduce(partial)
        else:
            reduced = partial

        layer_bias = layer.bias
        if layer.skip_bias_add:
            # Bias is handed back to the caller instead of being added.
            output, output_bias = reduced, layer_bias
        elif layer_bias is not None:
            output, output_bias = reduced + layer_bias, None
        else:
            output, output_bias = reduced, None

        if layer.return_bias:
            return output, output_bias
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only when ``source_layer`` is exactly a
        ``RowParallelLinear`` (subclasses do not match)."""
        return type(source_layer) is RowParallelLinear


# The following layer is based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Variant of RowParallelLinearWithLoRA that additionally slices
    LoRA B (and the LoRA bias) per tensor-parallel rank.

    Following S-LoRA, the slicing is done along the output dim, so each
    rank combines the partial sum of the row parallel base layer with a
    column partitioned LoRA output.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return the rows of ``lora_b`` that belong to this rank.

        The shard size is read from dimension 2 of the first stacked
        LoRA B tensor.
        """
        rows = self.lora_b_stacked[0].shape[2]
        first = self.tp_rank * rows
        return lora_b[first : first + rows, :]

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Return this rank's slice of ``bias``; ``None`` passes through.

        The shard size is read from dimension 2 of the first stacked
        LoRA bias tensor.
        """
        if bias is None:
            return bias
        # The cast only narrows the static type of the stacked bias.
        stacked_bias = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked)
        self.lora_bias_stacked = stacked_bias
        length = self.lora_bias_stacked[0].shape[2]
        first = self.tp_rank * length
        return bias[first : first + length]

    def apply(
        self, x: torch.Tensor, bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute the base-layer output and add the sharded LoRA output.

        Args:
            x: input tensor; it is flattened to 2-D before the LoRA ops.
            bias: accepted for interface compatibility; not used here.

        Returns:
            The combined output, viewed back to the shape produced by
            the base layer's quant method.
        """
        base_out = self.base_layer.quant_method.apply(self.base_layer, x)

        # Flatten both tensors to 2-D, remembering the shape to restore.
        x = x.view(-1, x.shape[-1])
        out_orig_shape = base_out.shape
        output = base_out.view(-1, out_orig_shape[-1])

        buffer_shape = (
            self.n_slices,
            x.shape[0],
            self.lora_a_stacked[0].shape[2],
        )
        buffer = torch.zeros(buffer_shape, dtype=torch.float32, device=x.device)

        shrink_result: Optional[torch.Tensor] = self.punica_wrapper.add_shrink(
            buffer, x, self.lora_a_stacked, 1.0
        )
        # Platforms without in-place updates hand the result back instead.
        if not current_platform.can_update_inplace():
            buffer = shrink_result
        if self.tp_size > 1:
            buffer = tensor_model_parallel_all_reduce(buffer)

        # As in S-LoRA, all_gather and all_reduce are fused: the column
        # partitioned LoRA output is added into a slice of the output
        # tensor, which is itself only a partial sum because of row
        # parallelism. A standard all_reduce is all that remains. Callers
        # must be aware that, unlike a normal row_parallel output, this
        # result has to be reduced before it is used.
        # NOTE offset are based on the rank.
        shard_size = self.lora_b_stacked[0].shape[2]
        expand_result: Optional[torch.Tensor] = self.punica_wrapper.add_expand(
            output,
            buffer,
            self.lora_b_stacked,
            self.lora_bias_stacked,
            self.output_slices,
            offset_start=self.tp_rank * shard_size,
            add_input=True,
        )

        if not current_platform.can_update_inplace():
            output = expand_result

        return output.view(*out_orig_shape)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Defer to the parent's check, passing ``decorate=False``."""
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )