weights.py 14.7 KB
Newer Older
1
2
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union

import torch
from safetensors import safe_open

from text_generation_server.utils.import_utils import SYSTEM


class WeightsLoader(ABC):
    """
    Instances of this type implement higher-level weight loading.

    At a low-level, every weight is stored in the Safetensors format.
    The interpretation of weights may be different however, for instance
    could be packed, quantized weights. Loaders are responsible for
    interpreting the raw tensors, sharding tensors in a manner compatible
    with the format, etc.
    """

    @abstractmethod
    def get_weights(self, weights: "Weights", prefix: str):
        """
        Get weights at the given prefix and apply without tensor parallelism.
        """
        ...

    @abstractmethod
    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """
        Get the packed weights at the given prefix with column-splitting for
        tensor parallelism. This method should be used when multiple different
        weights are packed into a tensor, for instance, query/key/value
        weights or a gate/up projection.

        The `block_sizes` determines the proportions of the packed tensors.
        The columns are split in equally sized blocks when `block_sizes` is an
        `int`, or in blocks proportional given to the sizes. For instance
        `[2, 1, 1]` will divide an input with dimensionality `1024` in
        `[512, 256, 256]`.
        """
        ...

    def get_weights_col(self, weights: "Weights", prefix: str):
        """
        Get weights at the given prefix and apply column-splitting for tensor
        parallelism.
        """
        # Default implementation: a single-prefix column split is just the
        # multi-prefix column load with one prefix, concatenated along dim 0.
        return weights.get_multi_weights_col([prefix], 0)

    @abstractmethod
    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """
        Get the weights at the given prefixes, column-split them for tensor
        parallelism, and then concatenate the weights along the given dimension.
        """
        ...

    @abstractmethod
    def get_weights_row(self, weights: "Weights", prefix: str):
        """
        Get the weights at the given prefix and apply row-splitting for tensor
        parallelism.
        """
        ...


76
77
78
79
80
81
82
83
84
85
86
class Weight(ABC):
    """Instances of this type implement unquantized/quantized/to-be
    quantized weights."""

    @abstractmethod
    def get_linear(self, bias: torch.Tensor):
        """Create a linear layer from this weight.

        NOTE(review): callers may pass `None` for bias-less layers; if
        confirmed, widen the annotation to `Optional[torch.Tensor]`.
        """
        ...


@dataclass
class UnquantizedWeight(Weight):
    """A plain torch tensor weight with no quantization applied."""

    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        """Wrap this weight (and `bias`) in the fast linear layer for the current system."""
        from text_generation_server.layers.linear import FastLinear, FastLinearROCm

        # ROCm has a dedicated fast path; every other system uses FastLinear.
        linear_cls = FastLinearROCm if SYSTEM == "rocm" else FastLinear
        return linear_cls(self.weight, bias)


99
class DefaultWeightsLoader(WeightsLoader):
    """Weight loader that loads (unquantized) Torch tensors.

    Loader that uses tensors as-is with the exception of applying sharding
    and/or concatenation.
    """

    def __init__(self, weight_class: Type[UnquantizedWeight]):
        """Create a loader. Weights will be wrapped using the given `weights_class`,
        normally this will be `UnquantizedWeight`, but a quantizer-specific class
        such as `Fp8Weight` can be used to quantize the weights during loading.
        """
        self.weight_class = weight_class

    def get_weights(self, weights: "Weights", prefix: str):
        # NOTE(review): unlike the other methods, the raw tensor is returned
        # without wrapping it in `self.weight_class` — confirm callers rely
        # on receiving a bare tensor here.
        return weights.get_tensor(f"{prefix}.weight")

    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        # Shard the packed projection along dim 0 so each packed sub-tensor
        # is split across shards (see Weights.get_packed_sharded).
        return self.weight_class(
            weights.get_packed_sharded(
                f"{prefix}.weight", dim=0, block_sizes=block_sizes
            ),
        )

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # Column-shard every prefix, then concatenate the shards along `dim`.
        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
        return self.weight_class(torch.cat(w, dim=dim))

    def get_weights_row(self, weights: "Weights", prefix: str):
        # Row splitting shards along dim 1 (the input dimension).
        return self.weight_class(
            weights.get_sharded(f"{prefix}.weight", dim=1),
        )


class Weights:
    """
    Dispenser of tensors stored across a set of Safetensors files.

    Indexes the tensors of several Safetensors files (each tensor name must
    occur in exactly one file) and hands them out, optionally sharded for
    tensor parallelism across `process_group`. Interpretation of
    packed/quantized tensors is delegated to `weights_loader`.
    """

    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        weights_loader: "WeightsLoader",
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        # Build a routing table mapping every tensor name to the single file
        # that contains it; a duplicated name across files is a hard error.
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self.weights_loader = weights_loader
        # Lazily-opened safetensors handles, keyed by filename.
        self._handles = {}

    def _get_handle(self, filename):
        """Return a cached safetensors handle for `filename`, opening it on first use."""
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        """
        Resolve `tensor_name` — optionally under `self.prefix`, optionally
        through `self.aliases` — to `(filename, resolved_name)`.

        Raises:
            RuntimeError: if the tensor is not present in any file.
        """
        names = [tensor_name]
        if self.prefix is not None:
            prefixed = f"{self.prefix}.{tensor_name}"
            names.append(prefixed)
        for name in names:
            filename = self.routing.get(name, None)
            if filename is not None:
                return str(filename), name

            aliases = self.aliases.get(name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
        """Return a lazy safetensors slice for `tensor_name` (no data read yet)."""
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def _has_tensor(self, tensor_name: str):
        """Return whether `tensor_name` resolves to a stored tensor."""
        try:
            self.get_filename(tensor_name)
        except Exception:
            return False
        return True

    def get_shape(self, tensor_name: str):
        """Return the shape of `tensor_name` without loading its data."""
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
        """
        Load a full (unsharded) tensor, optionally casting it to `self.dtype`
        and moving it to `self.device`.
        """
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. Exl2 uses int16
        # as well. FP8 uses torch.float8_e4m3fn
        if (
            tensor.dtype
            not in [
                torch.float8_e4m3fn,
                torch.int16,
                torch.int32,
                torch.int64,
            ]
            and to_dtype
        ):
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True):
        """
        Load this rank's shard of the tensor along dimension `dim`.

        The size along `dim` is split into blocks of ceil(size / world_size),
        so the last rank may receive a smaller remainder block.
        """
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. exl2 uses int16.
        # FP8 uses torch.float8_e4m3fn.
        if (
            tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
            and to_dtype
        ):
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int, to_dtype=True):
        """
        Load this rank's shard of the tensor along `dim`; the size along
        `dim` must be evenly divisible by the world size.
        """
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)

    def get_packed_sharded(
        self,
        tensor_name: str,
        dim: int,
        block_sizes: Union[int, List[int]],
        to_dtype=True,
    ) -> torch.Tensor:
        """
        Get a shard from a tensor that packs multiple tensors.

        When a tensor packs multiple tensors (such as QKV or an up
        projection + gate projection), sharding with `get_sharded` is not
        safe since it would not split the packed tensors across shards.

        This method shards a tensor, such that the packed tensors are
        split across shards.

        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.
        """
        slice_ = self._get_slice(tensor_name)
        total_size = slice_.get_shape()[dim]
        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

        world_size = self.process_group.size()
        rank = self.process_group.rank()

        tensors = []
        block_offset = 0
        # Take this rank's slice of every packed block, then re-concatenate.
        for block_size in block_sizes:
            assert (
                block_size % world_size == 0
            ), f"Prepacked tensor cannot be sharded across {world_size} shards"
            shard_block_size = block_size // world_size
            start = rank * shard_block_size
            stop = (rank + 1) * shard_block_size
            if dim == 0:
                tensor = slice_[block_offset + start : block_offset + stop]
            elif dim == 1:
                tensor = slice_[:, block_offset + start : block_offset + stop]
            else:
                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
            tensors.append(tensor)
            block_offset += block_size
        tensor = torch.cat(tensors, dim=dim)
        tensor = tensor.to(device=self.device)

        # Avoid casting quantizer dtypes.
        if (
            tensor.dtype
            not in [
                torch.float8_e4m3fn,
                torch.int16,
                torch.int32,
                torch.int64,
            ]
            and to_dtype
        ):
            tensor = tensor.to(dtype=self.dtype)

        return tensor

    def get_weights(self, prefix: str):
        """Load unsharded weights at `prefix` via the active weights loader."""
        return self.weights_loader.get_weights(self, prefix)

    def get_weights_col_packed_qkv(
        self,
        prefix: str,
        num_heads: int,
        num_key_value_heads: int,
    ):
        """Column-shard a packed QKV projection, with blocks proportional to
        the query/key/value head counts."""
        return self.get_weights_col_packed(
            prefix, [num_heads, num_key_value_heads, num_key_value_heads]
        )

    def get_weights_col_packed_gate_up(self, prefix: str):
        """Column-shard a packed gate/up projection (two equal blocks)."""
        return self.get_weights_col_packed(prefix, 2)

    def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
        """
        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.
        """
        return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)

    def get_weights_col(self, prefix: str):
        """Column-shard the weights at `prefix` via the active loader."""
        return self.weights_loader.get_weights_col(self, prefix)

    def get_multi_weights_col(self, prefixes: List[str], dim: int):
        """Column-shard each prefix and concatenate along `dim` via the loader."""
        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)

    def get_tensor_shard(self, var, dim):
        """
        Shard an already-materialized tensor `var` along `dim` for this rank.

        Uses floor division for the block size, so any remainder beyond
        `world_size * (size // world_size)` is dropped. The shard is cast to
        `self.dtype` and moved to `self.device`.
        """
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_weights_row(self, prefix: str):
        """Row-shard the weights at `prefix` via the active loader."""
        return self.weights_loader.get_weights_row(self, prefix)

    @contextmanager
    def use_loader(self, weights_loader: "WeightsLoader"):
        """
        This method is a context manager that can be used to use `Weights` with
        a different loader for the duration of the context.
        """

        old_loader = self.weights_loader
        self.weights_loader = weights_loader
        try:
            yield
        finally:
            # Always restore the previous loader, even if the body raised.
            self.weights_loader = old_loader

403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429

def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    """
    Convert block count or proportions to block sizes.

    This function accepts

    - The number of blocks (int), in which case the block size is
      total_size//blocks; or
    - A list of block sizes (List[int]).

    In the latter case, if sum(blocks) < total_size, the ratios between
    the block sizes will be preserved. For instance, if blocks is
    [2, 1, 1] and total_size is 1024, the returned block sizes are
    [512, 256, 256].
    """
    if isinstance(blocks, list):
        total_blocks = sum(blocks)
        assert (
            total_size % total_blocks == 0
        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]
    else:
        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
        single_size = total_size // blocks
        return [single_size] * blocks