weights.py 13.7 KB
Newer Older
1
from abc import ABC, abstractmethod
2
from contextlib import contextmanager
3
4
from dataclasses import dataclass
from enum import Enum, auto
5
from pathlib import Path
6
from typing import Dict, List, Optional, Tuple, Union
7

8
import torch
9
10
from safetensors import safe_open
from text_generation_server.utils.import_utils import SYSTEM
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68


class WeightsLoader(ABC):
    """
    Instances of this type implement higher-level weight loading.

    At a low-level, every weight is stored in the Safetensors format.
    The interpretation of weights may be different however, for instance
    could be packed, quantized weights. Loaders are responsible for
    interpreting the raw tensors, sharding tensors in a manner compatible
    with the format, etc.
    """

    @abstractmethod
    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """
        Get the packed weights at the given prefix with column-splitting for
        tensor parallelism. This method should be used when multiple different
        weights are packed into a tensor, for instance, query/key/value
        weights or a gate/up projection.

        The `block_sizes` determines the proportions of the packed tensors.
        The columns are split in equally sized blocks when `block_sizes` is an
        `int`, or in blocks proportional given to the sizes. For instance
        `[2, 1, 1]` will divide an input with dimensionality `1024` in
        `[512, 256, 256]`.
        """
        ...

    def get_weights_col(self, weights: "Weights", prefix: str):
        """
        Get weights at the given prefix and apply column-splitting for tensor
        parallelism.
        """
        # Default implementation: a single-prefix column load is just a
        # multi-weight column load with one prefix.
        return weights.get_multi_weights_col([prefix], 0)

    @abstractmethod
    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """
        Get the weights at the given prefixes, column-split them for tensor
        parallelism, and then concatenate the weights along the given dimension.
        """
        ...

    @abstractmethod
    def get_weights_row(self, weights: "Weights", prefix: str):
        """
        Get the weights at the given prefix and apply row-splitting for tensor
        parallelism.
        """
        ...


69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class Weight(ABC):
    """Instances of this type implement unquantized/quantized/to-be
    quantized weights."""

    @abstractmethod
    def get_linear(self, bias: torch.Tensor):
        """Create a linear layer from this weight."""
        ...


@dataclass
class UnquantizedWeight(Weight):
    """A plain (unquantized) weight tensor.

    Subclasses `Weight` so that instances pass `isinstance(w, Weight)`
    checks alongside quantized weight implementations.
    """

    # The raw tensor, already sharded/concatenated by the loader.
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        """Create a linear layer wrapping this weight (ROCm gets a
        specialized implementation)."""
        # NOTE(review): imported lazily — presumably to avoid a circular
        # import between the layers and weights modules; confirm.
        from text_generation_server.layers.linear import FastLinear, FastLinearROCm

        if SYSTEM == "rocm":
            return FastLinearROCm(self.weight, bias)
        else:
            return FastLinear(self.weight, bias)


class DefaultWeightsLoader(WeightsLoader):
    """Weight loader that loads (unquantized) Torch tensors.

    Uses tensors as-is, with the exception of applying sharding and/or
    concatenation.
    """

    def __init__(self, weight_class):
        """Create a loader.

        Weights will be wrapped using the given `weight_class`, normally
        this will be `UnquantizedWeight`, but a quantizer-specific class
        such as `Fp8Weight` can be used to quantize the weights during
        loading.
        """
        self.weight_class = weight_class

    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """Load a packed tensor (e.g. QKV) and shard it column-wise."""
        return self.weight_class(
            weights.get_packed_sharded(
                f"{prefix}.weight", dim=0, block_sizes=block_sizes
            ),
        )

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """Column-shard each prefix's weight, then concatenate along `dim`."""
        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
        return self.weight_class(torch.cat(w, dim=dim))

    def get_weights_row(self, weights: "Weights", prefix: str):
        """Load the weight at `prefix`, sharded along the row (input) dim."""
        return self.weight_class(
            weights.get_sharded(f"{prefix}.weight", dim=1),
        )


class Weights:
    """Index over a set of safetensors files.

    Resolves tensor names (with optional prefix and aliases) to the file
    containing them, loads tensors sharded for tensor parallelism, and
    delegates format-specific interpretation to a `WeightsLoader`.
    """

    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        weights_loader: "WeightsLoader",
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        """
        Args:
            filenames: safetensors files to index.
            device: target device for loaded tensors.
            dtype: dtype tensors are converted to (quantizer integer dtypes
                are exempted, see `get_tensor`).
            process_group: distributed process group; its rank/size select
                the shard this instance loads.
            weights_loader: loader implementing the on-disk weight format.
            aliases: alternative names for tensors (e.g. tied embeddings).
            prefix: optional prefix tried in addition to the bare name.

        Raises:
            RuntimeError: if the same tensor key occurs in several files.
        """
        # Map every tensor key to the single file providing it; duplicates
        # across files are ambiguous and therefore rejected.
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self.weights_loader = weights_loader
        # Cache of open safetensors handles, keyed by filename.
        self._handles = {}

    def _get_handle(self, filename):
        # Open each file at most once; handles stay open for reuse.
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        """Resolve `tensor_name` to `(filename, resolved_name)`.

        Tries the bare name, then the prefixed name; for each candidate,
        aliases are consulted as a fallback.

        Raises:
            RuntimeError: if no candidate name is found in any file.
        """
        names = [tensor_name]
        if self.prefix is not None:
            prefixed = f"{self.prefix}.{tensor_name}"
            names.append(prefixed)
        for name in names:
            filename = self.routing.get(name, None)
            if filename is not None:
                return str(filename), name

            aliases = self.aliases.get(name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
        # A lazy slice: nothing is read from disk until it is indexed.
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def _has_tensor(self, tensor_name: str):
        """Return whether `tensor_name` resolves to any file."""
        try:
            self.get_filename(tensor_name)
        except RuntimeError:
            # get_filename only raises RuntimeError for a missing weight.
            return False
        return True

    def get_shape(self, tensor_name: str):
        """Return the (unsharded) shape of the named tensor."""
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True):
        """Load a full (unsharded) tensor, converting to `self.dtype`."""
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. Exl2 uses int16
        # as well.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
        """Load this rank's shard along `dim` without requiring the size
        to divide evenly; the last rank may get a smaller shard."""
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        # Ceiling division so all elements are covered when `size` is not
        # divisible by `world_size`.
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. exl2 uses int16.
        if tensor.dtype not in (torch.int16, torch.int32):
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
        """Load this rank's shard along `dim`, requiring even divisibility."""
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def get_packed_sharded(
        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
    ) -> torch.Tensor:
        """
        Get a shard from a tensor that packs multiple tensors.

        When a tensor packs multiple tensors (such as QKV or an up
        projection + gate projection), sharding with `get_sharded` is not
        safe since it would not split the packed tensors across shards.

        This method shards a tensor, such that the packed tensors are
        split across shards.

        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.
        """
        slice_ = self._get_slice(tensor_name)
        total_size = slice_.get_shape()[dim]
        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

        world_size = self.process_group.size()
        rank = self.process_group.rank()

        # Take this rank's slice out of every packed block, then
        # re-concatenate so the result keeps the packed layout.
        tensors = []
        block_offset = 0
        for block_size in block_sizes:
            assert (
                block_size % world_size == 0
            ), f"Prepacked tensor cannot be sharded across {world_size} shards"
            shard_block_size = block_size // world_size
            start = rank * shard_block_size
            stop = (rank + 1) * shard_block_size
            if dim == 0:
                tensor = slice_[block_offset + start : block_offset + stop]
            elif dim == 1:
                tensor = slice_[:, block_offset + start : block_offset + stop]
            else:
                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
            tensors.append(tensor)
            block_offset += block_size
        tensor = torch.cat(tensors, dim=dim)
        tensor = tensor.to(device=self.device)

        # Avoid casting quantizer dtypes.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)

        return tensor

    def get_weights_col_packed_qkv(
        self,
        prefix: str,
        num_heads: int,
        num_key_value_heads: int,
    ):
        """Shard packed QKV weights proportionally to the number of query
        heads and key/value heads."""
        return self.get_weights_col_packed(
            prefix, [num_heads, num_key_value_heads, num_key_value_heads]
        )

    def get_weights_col_packed_gate_up(self, prefix: str):
        """Shard packed gate/up projection weights (two equal blocks)."""
        return self.get_weights_col_packed(prefix, 2)

    def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
        """
        Get the packed weights at the given prefix with column-splitting
        for tensor parallelism.

        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.
        """
        return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)

    def get_weights_col(self, prefix: str):
        """Column-shard the weight at `prefix` via the active loader."""
        return self.weights_loader.get_weights_col(self, prefix)

    def get_multi_weights_col(self, prefixes: List[str], dim: int):
        """Column-shard the weights at `prefixes` and concatenate along
        `dim`, via the active loader."""
        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)

    def get_tensor_shard(self, var, dim):
        """Shard an already-loaded tensor `var` along `dim` for this rank.

        NOTE(review): unlike `get_sharded`, this does not check that the
        size divides evenly; any remainder rows/columns are silently
        dropped — confirm callers guarantee divisibility.
        """
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_weights_row(self, prefix: str):
        """Row-shard the weight at `prefix` via the active loader."""
        return self.weights_loader.get_weights_row(self, prefix)

    @contextmanager
    def use_loader(self, weights_loader: "WeightsLoader"):
        """
        This method is a context manager that can be used to use `Weights` with
        a different loader for the duration of the context.
        """
        old_loader = self.weights_loader
        self.weights_loader = weights_loader
        try:
            yield
        finally:
            # Always restore, even if the body raised.
            self.weights_loader = old_loader

364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390

def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    """
    Convert block count or proportions to block sizes.

    This function accepts

    - The number of blocks (int), in which case the block size is
      total_size//blocks; or
    - A list of block sizes (List[int]).

    In the latter case, if sum(blocks) < total_size, the ratios between
    the block sizes will be preserved. For instance, if blocks is
    [2, 1, 1] and total_size is 1024, the returned block sizes are
    [512, 256, 256].
    """
    if isinstance(blocks, list):
        total_blocks = sum(blocks)
        assert (
            total_size % total_blocks == 0
        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]
    else:
        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
        single_size = total_size // blocks
        return [single_size] * blocks