tensor_parallel.py 8.75 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
2
import torch
from torch.nn import functional as F
3
from typing import Iterable, List
Nicolas Patry's avatar
Nicolas Patry committed
4
from text_generation_server.layers.linear import get_linear, FastLinear
Nicolas Patry's avatar
Nicolas Patry committed
5
from text_generation_server.utils.import_utils import SYSTEM
Wang, Yi's avatar
Wang, Yi committed
6

Nicolas Patry's avatar
Nicolas Patry committed
7
if SYSTEM == "ipex":
Wang, Yi's avatar
Wang, Yi committed
8
    import intel_extension_for_pytorch as ipex
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27


class LayerConcat(torch.nn.Module):
    """
    Apply multiple layers to the input and concatenate their
    outputs.
    """

    def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
        """
        `layers` are each applied to the same input; `dim` is the dimension
        along which their outputs are concatenated.

        The layers are held in a `torch.nn.ModuleList`. This registers them as
        submodules, so they are visible to `parameters()`, `state_dict()`,
        `.to()`, etc. It also materializes a one-shot iterator once, instead of
        letting the first forward pass exhaust it.
        """
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)
        self.dim = dim

    def forward(self, x: torch.Tensor):
        """Return every layer's output on `x`, concatenated along `self.dim`."""
        outputs = [layer(x) for layer in self.layers]
        return torch.cat(outputs, self.dim)
Nicolas Patry's avatar
Nicolas Patry committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46


class SuperLayer(torch.nn.Module):
    """Module wrapper whose forward pass is delegated to a wrapped `linear`."""

    def __init__(self, linear):
        super().__init__()
        # Kept as `self.linear`; subclasses read it directly.
        self.linear = linear

    def forward(self, x):
        """Run `x` through the wrapped layer's `forward` method."""
        wrapped = self.linear
        result = wrapped.forward(x)
        return result


class TensorParallelHead(SuperLayer):
    """
    Output head whose weight may be split across a process group.

    When `should_gather` is set, each rank computes its own slice of the
    output dimension and the slices are all-gathered so every rank returns
    the full output. Otherwise the wrapped linear is applied unchanged.
    """

    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        """
        Build the head from the tensors stored under `prefix`.

        - exl2: first try the unquantized `{prefix}.weight`. If that fails for
          any reason, load the quantized column weights instead. Gather
          whenever the process group has more than one rank.
        - More than one rank: shard `{prefix}.weight` along dim 0 and gather in
          `forward`. If sharding asserts, load the whole tensor and skip the
          gather.
        - Single rank: load the whole tensor, no gather.
        """
        if config.quantize == "exl2":
            try:
                # If the piece and LM head embeddings are shared, we have
                # non-quantized weights...
                weight = weights.get_tensor(f"{prefix}.weight")
            except Exception:
                # ...otherwise they are quantized.
                weight = weights.get_weights_col(prefix)
            should_gather = weights.process_group.size() > 1
        elif weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        return TensorParallelHead(
            get_linear(weight, bias=None),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply the local shard and, if `should_gather`, all-gather the shards.

        Rank outputs are concatenated along the last dimension.
        """
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            # Fast path: the local matmul writes straight into the buffer
            # handed to all_gather_into_tensor, so no extra copy is needed.
            # NOTE(review): no bias is applied on this path; `load` always
            # builds the head with bias=None.
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                # Gather into an (out_dim * world_size, batch) buffer so each
                # rank's slice is a contiguous block. The matmul writes into a
                # transposed (batch, out_dim) view of this rank's block, and
                # the result is transposed back before returning.
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)

            if SYSTEM == "ipex":
                ipex.distributed.all_gather_into_tensor(
                    world_out, gather_input, group=self.process_group
                )
            else:
                torch.distributed.all_gather_into_tensor(
                    world_out, gather_input, group=self.process_group
                )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        # Generic path: gather full per-rank outputs and concatenate them
        # along the last dimension.
        output = super().forward(input)
        world_output = [torch.empty_like(output) for _ in range(world_size)]
        if SYSTEM == "ipex":
            ipex.distributed.all_gather(world_output, output, group=self.process_group)
        else:
            torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output


class TensorParallelColumnLinear(SuperLayer):
    """
    Column-parallel linear layer.

    Weights come from the column loaders on `weights` (`get_weights_col*`),
    and a bias, when present, is sharded along dim 0. `forward` is inherited
    from `SuperLayer` and performs no cross-rank communication.
    """

    @classmethod
    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the gate and up projections were joined after the fact.

        Raises `NotImplementedError` if `bias` is requested.
        """
        weight = weights.get_weights_col_packed_gate_up(prefix)
        if bias:
            raise NotImplementedError("packed_gate_up only implemented without bias")
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)

    @classmethod
    def load_qkv(
        cls,
        config,
        prefix: str,
        weights,
        bias: bool,
        num_heads: int,
        num_key_value_heads: int,
    ):
        """Specific method when the QKV was joined after the fact.

        `num_heads` and `num_key_value_heads` are forwarded to
        `weights.get_weights_col_packed_qkv`. Raises `NotImplementedError` if
        `bias` is requested.
        """
        weight = weights.get_weights_col_packed_qkv(
            prefix,
            num_heads=num_heads,
            num_key_value_heads=num_key_value_heads,
        )
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load one column-sharded layer from `prefix` (bias sharded on dim 0)."""
        weight = weights.get_weights_col(prefix)
        if bias:
            bias = weights.get_sharded(f"{prefix}.bias", dim=0)
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        """
        Load several column-sharded layers (one per prefix) as a single layer.

        For exl2, each prefix becomes its own linear, and their outputs are
        concatenated at forward time by `LayerConcat` (last dimension).
        Otherwise the weights are merged by
        `weights.get_multi_weights_col(prefixes, dim=dim)`, and the per-prefix
        biases are concatenated along `dim`.
        """
        if config.quantize == "exl2":
            linears = []
            for prefix in prefixes:
                weight = weights.get_weights_col(prefix)
                # NOTE(review): unlike `load` and the branch below, the bias
                # is read whole with `get_tensor` rather than sharded.
                # Presumably only correct when exl2 runs unsharded; confirm.
                b = weights.get_tensor(f"{prefix}.bias") if bias else None
                linears.append(get_linear(weight, b))
            linear = LayerConcat(linears)
        else:
            weight = weights.get_multi_weights_col(prefixes, dim=dim)
            if bias:
                b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
                bias = torch.cat(b, dim=dim)
            else:
                bias = None
            linear = get_linear(weight, bias)
        return cls(linear)

Nicolas Patry's avatar
Nicolas Patry committed
182
183
184
185
186
187
188
189

class TensorParallelRowLinear(SuperLayer):
    """
    Row-parallel linear layer.

    Each rank applies its shard of the weight (loaded with
    `weights.get_weights_row`). `forward` then all-reduces the partial outputs
    across `process_group`.
    """

    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load the row-sharded weight at `prefix`, plus its bias on rank 0 only."""
        weight = weights.get_weights_row(prefix)

        if bias and weights.process_group.rank() == 0:
            # The bias lives only on the first rank. The all-reduce in
            # `forward` combines every rank's output, so a bias on each rank
            # would be added once per rank.
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        """
        Apply the local shard. If `reduce` is set and the group has more than
        one rank, all-reduce the result in place (sum by default in
        torch.distributed).

        With `reduce=False`, the caller receives this rank's partial output
        (which includes the bias only on rank 0) and must reduce it itself.
        """
        out = super().forward(input)
        if self.process_group.size() > 1 and reduce:
            if SYSTEM == "ipex":
                ipex.distributed.all_reduce(out, group=self.process_group)
            else:
                torch.distributed.all_reduce(out, group=self.process_group)
        return out


class TensorParallelEmbedding(torch.nn.Module):
    """
    Embedding table split along the vocabulary (row) dimension.

    This rank owns ids in [min_id, max_id). Any other id is mapped to an
    extra all-zero row. When `reduce` is set and the group has more than one
    rank, the per-rank results are all-reduced.
    """

    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group
        world_size = process_group.size()
        rank = process_group.rank()

        # Rows of the full table owned by this rank: [min_id, max_id).
        block_size = (num_embeddings + world_size - 1) // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        # Index of the zero row appended below. Usually block_size, but it
        # can be smaller when the vocab size does not divide evenly.
        self.null_idx = weight.shape[0]
        self.process_group = process_group
        self.reduce = reduce

        # Append one all-zero row, used to mask out-of-shard ids.
        self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Look up `input` ids in the local shard; ids owned by other ranks yield zeros."""
        # Ids outside [min_id, max_id) go to `self.null_idx` (the zero row).
        # In-range ids are shifted into [0, max_id - min_id).
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            if SYSTEM == "ipex":
                ipex.distributed.all_reduce(out, group=self.process_group)
            else:
                torch.distributed.all_reduce(out, group=self.process_group)
        return out