import torch
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union

from safetensors import safe_open

from text_generation_server.utils.import_utils import SYSTEM


class WeightsLoader(ABC):
    """
    Instances of this type implement higher-level weight loading.

    At a low level, every weight is stored in the Safetensors format.
    The interpretation of weights may differ, however; for instance they
    could be packed, quantized weights. Loaders are responsible for
    interpreting the raw tensors, sharding tensors in a manner compatible
    with the format, etc.
    """

    @abstractmethod
    def get_weights(self, weights: "Weights", prefix: str):
        """
        Get weights at the given prefix and apply without tensor parallelism.
        """
        ...

    @abstractmethod
    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """
        Get the packed weights at the given prefix with column-splitting for
        tensor parallelism. This method should be used when multiple different
        weights are packed into a tensor, for instance, query/key/value
        weights or a gate/up projection.

        The `block_sizes` determines the proportions of the packed tensors.
        The columns are split in equally sized blocks when `block_sizes` is an
        `int`, or in blocks proportional to the given sizes. For instance
        `[2, 1, 1]` will divide an input with dimensionality `1024` in
        `[512, 256, 256]`.
        """
        ...

    def get_weights_col(self, weights: "Weights", prefix: str):
        """
        Get weights at the given prefix and apply column-splitting for tensor
        parallelism.

        Default implementation: delegates to `get_multi_weights_col` with a
        single prefix, concatenating along dim 0.
        """
        return weights.get_multi_weights_col([prefix], 0)

    @abstractmethod
    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """
        Get the weights at the given prefixes, column-split them for tensor
        parallelism, and then concatenate the weights along the given dimension.
        """
        ...

    @abstractmethod
    def get_weights_row(self, weights: "Weights", prefix: str):
        """
        Get the weights at the given prefix and apply row-splitting for tensor
        parallelism.
        """
        ...


class Weight(ABC):
    """Instances of this type implement unquantized/quantized/to-be
    quantized weights."""

    @abstractmethod
    def get_linear(self, bias: torch.Tensor):
        """Create a linear layer from this weight, using `bias` as the
        layer's bias term."""
        ...


@dataclass
class UnquantizedWeight(Weight):
    """A plain, unquantized weight tensor."""

    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        """Build a linear layer backed by this weight, selecting the
        ROCm-specialized implementation when running on AMD GPUs."""
        from text_generation_server.layers.linear import FastLinear, FastLinearROCm

        linear_cls = FastLinearROCm if SYSTEM == "rocm" else FastLinear
        return linear_cls(self.weight, bias)


class DefaultWeightsLoader(WeightsLoader):
    """Weight loader that loads (unquantized) Torch tensors.

    The loader uses tensors as-is, with the exception of applying sharding
    and/or concatenation.
    """

    def __init__(self, weight_class: Type[UnquantizedWeight]):
        """Create a loader. Weights will be wrapped using the given `weights_class`,
        normally this will be `UnquantizedWeight`, but a quantizer-specific class
        such as `Fp8Weight` can be used to quantize the weights during loading.
        """
        self.weight_class = weight_class

    def get_weights(self, weights: "Weights", prefix: str):
        # NOTE(review): unlike the other methods, this returns the raw tensor
        # without wrapping it in `self.weight_class` — confirm callers rely on
        # receiving a plain tensor here.
        return weights.get_tensor(f"{prefix}.weight")

    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        # Packed tensors (e.g. QKV) are sharded along dim 0 so that each
        # packed sub-tensor is split across ranks.
        return self.weight_class(
            weights.get_packed_sharded(
                f"{prefix}.weight", dim=0, block_sizes=block_sizes
            ),
        )

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # Column-shard each prefix, then concatenate along `dim`.
        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
        return self.weight_class(torch.cat(w, dim=dim))

    def get_weights_row(self, weights: "Weights", prefix: str):
        # Row sharding splits along dim 1.
        return self.weight_class(
            weights.get_sharded(f"{prefix}.weight", dim=1),
        )


class Weights:
    """Lazily loads tensors from a set of Safetensors checkpoint files.

    Tensors are resolved by name (optionally through `prefix` and
    `aliases`), moved to `device`, converted to `dtype`, and sharded
    across the ranks of `process_group` for tensor parallelism.
    Higher-level interpretation (e.g. quantized formats) is delegated to
    the configured `weights_loader`.
    """

    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        weights_loader: WeightsLoader,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        # Build a routing table mapping each tensor name to the file that
        # contains it. A key appearing in more than one file would make the
        # lookup ambiguous, so it is rejected outright.
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self.weights_loader = weights_loader
        # Lazily opened safetensors handles, keyed by filename.
        self._handles = {}

    def _get_handle(self, filename):
        # Open each file at most once and cache the handle for reuse.
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        """Resolve a tensor name (or one of its aliases, possibly under
        `self.prefix`) to `(filename, resolved_name)`.

        Raises:
            RuntimeError: if no file contains the tensor or any alias.
        """
        names = [tensor_name]
        if self.prefix is not None:
            prefixed = f"{self.prefix}.{tensor_name}"
            names.append(prefixed)
        for name in names:
            filename = self.routing.get(name, None)
            if filename is not None:
                return str(filename), name

            aliases = self.aliases.get(name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
        # Return a lazy slice object so callers can read sub-ranges
        # without materializing the full tensor.
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def has_tensor(self, tensor_name: str):
        """Return whether the tensor (or one of its aliases) exists in any file."""
        try:
            self.get_filename(tensor_name)
        except Exception:
            return False
        return True

    def get_shape(self, tensor_name: str):
        """Return the tensor's shape without loading its data."""
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(
        self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
    ) -> torch.Tensor:
        """Load a full (unsharded) tensor, optionally converting it to
        `self.dtype` and moving it to `self.device`."""
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. Exl2 uses int16
        # as well. FP8 uses torch.float8_e4m3fn
        if (
            tensor.dtype
            not in [
                torch.float8_e4m3fn,
                torch.int16,
                torch.int32,
                torch.int64,
            ]
            and to_dtype
        ):
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(
        self, tensor_name: str, dim: int, to_device=True, to_dtype=True
    ):
        """Load this rank's shard of a tensor along `dim`.

        The size does not need to divide evenly by the world size: blocks
        are ceil-divided, so the last rank may receive a smaller shard.
        """
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. exl2 uses int16.
        # FP8 uses torch.float8_e4m3fn.
        # NOTE(review): unlike `get_tensor`, int64 is not exempted here —
        # confirm whether that asymmetry is intentional.
        if (
            tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
            and to_dtype
        ):
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):
        """Load this rank's shard of a tensor along `dim`, requiring that
        the dimension divides evenly across the world size."""
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(
            tensor_name, dim, to_device=to_device, to_dtype=to_dtype
        )

    def get_packed_sharded(
        self,
        tensor_name: str,
        dim: int,
        block_sizes: Union[int, List[int]],
        to_dtype=True,
    ) -> torch.Tensor:
        """
        Get a shard from a tensor that packs multiple tensors.

        When a tensor packs multiple tensors (such as QKV or an up
        projection + gate projection), sharding with `get_sharded` is not
        safe since it would not split the packed tensors across shards.

        This method shards a tensor, such that the packed tensors are
        split across shards.

        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional to the given sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.
        """
        slice_ = self._get_slice(tensor_name)
        total_size = slice_.get_shape()[dim]
        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

        world_size = self.process_group.size()
        rank = self.process_group.rank()

        # Take this rank's slice of every packed block, then re-concatenate
        # them so each rank holds a shard of every packed sub-tensor.
        tensors = []
        block_offset = 0
        for block_size in block_sizes:
            assert (
                block_size % world_size == 0
            ), f"Prepacked tensor cannot be sharded across {world_size} shards"
            shard_block_size = block_size // world_size
            start = rank * shard_block_size
            stop = (rank + 1) * shard_block_size
            if dim == 0:
                tensor = slice_[block_offset + start : block_offset + stop]
            elif dim == 1:
                tensor = slice_[:, block_offset + start : block_offset + stop]
            else:
                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
            tensors.append(tensor)
            block_offset += block_size
        tensor = torch.cat(tensors, dim=dim)
        tensor = tensor.to(device=self.device)

        # Avoid casting quantizer dtypes.
        if (
            tensor.dtype
            not in [
                torch.float8_e4m3fn,
                torch.int16,
                torch.int32,
                torch.int64,
            ]
            and to_dtype
        ):
            tensor = tensor.to(dtype=self.dtype)

        return tensor

    def get_weights(self, prefix: str):
        """Delegate whole-weight loading to the configured loader."""
        return self.weights_loader.get_weights(self, prefix)

    def get_weights_col_packed_qkv(
        self,
        prefix: str,
        num_heads: int,
        num_key_value_heads: int,
    ):
        """Shard a packed QKV projection, splitting proportionally to the
        number of query heads and key/value heads."""
        return self.get_weights_col_packed(
            prefix, [num_heads, num_key_value_heads, num_key_value_heads]
        )

    def get_weights_col_packed_gate_up(self, prefix: str):
        """Shard a packed gate/up projection (two equally sized blocks)."""
        return self.get_weights_col_packed(prefix, 2)

    def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
        """
        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional to the given sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.
        """
        return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)

    def get_weights_col(self, prefix: str):
        """Delegate column-sharded loading to the configured loader."""
        return self.weights_loader.get_weights_col(self, prefix)

    def get_multi_weights_col(self, prefixes: List[str], dim: int):
        """Delegate multi-prefix column-sharded loading to the loader."""
        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)

    def get_tensor_shard(self, var, dim):
        """Shard an already-materialized tensor `var` along `dim` for this
        rank, then convert to `self.dtype` and move to `self.device`.

        NOTE(review): uses floor division, so a dimension that does not
        divide evenly silently drops trailing rows/columns — confirm
        callers only pass evenly divisible tensors.
        """
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_weights_row(self, prefix: str):
        """Delegate row-sharded loading to the configured loader."""
        return self.weights_loader.get_weights_row(self, prefix)

    @contextmanager
    def use_loader(self, weights_loader: WeightsLoader):
        """
        This method is a context manager that can be used to use `Weights` with
        a different loader for the duration of the context.
        """
        old_loader = self.weights_loader
        self.weights_loader = weights_loader
        try:
            yield
        finally:
            # Always restore the previous loader, even if the body raised.
            self.weights_loader = old_loader

    @property
    def loader(self):
        # The currently active `WeightsLoader`.
        return self.weights_loader
jixx's avatar
init  
jixx committed
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439


def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    """
    Convert block count or proportions to block sizes.

    This function accepts

    - The number of blocks (int), in which case the block size is
      total_size//blocks; or
    - A list of block sizes (List[int]).

    In the latter case, if sum(blocks) < total_size, the ratios between
    the block sizes will be preserved. For instance, if blocks is
    [2, 1, 1] and total_size is 1024, the returned block sizes are
    [512, 256, 256].
    """
    if isinstance(blocks, list):
        total_blocks = sum(blocks)
        assert (
            total_size % total_blocks == 0
        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]
    else:
        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
        single_size = total_size // blocks
        return [single_size] * blocks