# weights.py
1
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
from safetensors import safe_open
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86


class WeightsLoader(ABC):
    """
    Instances of this type implement higher-level weight loading.

    At a low level, every weight is stored in the Safetensors format.
    The interpretation of weights may differ, however; for instance, they
    could be packed, quantized weights. Loaders are responsible for
    interpreting the raw tensors, sharding tensors in a manner compatible
    with the format, etc.
    """

    @abstractmethod
    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """
        Get the packed weights at the given prefix with column-splitting for
        tensor parallelism. This method should be used when multiple different
        weights are packed into a tensor, for instance, query/key/value
        weights or a gate/up projection.

        The `block_sizes` argument determines the proportions of the packed
        tensors. The columns are split in equally sized blocks when
        `block_sizes` is an `int`, or in blocks proportional to the given
        sizes. For instance, `[2, 1, 1]` will divide an input with
        dimensionality `1024` into `[512, 256, 256]`.
        """
        ...

    def get_weights_col(self, weights: "Weights", prefix: str):
        """
        Get weights at the given prefix and apply column-splitting for tensor
        parallelism.

        The default implementation is the single-prefix case of
        `get_multi_weights_col` (concatenating along dimension 0), dispatched
        through `weights.get_multi_weights_col`.
        """
        return weights.get_multi_weights_col([prefix], 0)

    @abstractmethod
    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """
        Get the weights at the given prefixes, column-split them for tensor
        parallelism, and then concatenate the weights along the given
        dimension.
        """
        ...

    @abstractmethod
    def get_weights_row(self, weights: "Weights", prefix: str):
        """
        Get the weights at the given prefix and apply row-splitting for tensor
        parallelism.
        """
        ...


class DefaultWeightsLoader(WeightsLoader):
    """
    Plain loader: every tensor is taken exactly as it is stored, apart from
    being sharded for tensor parallelism and/or concatenated.
    """

    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        # Packed tensors (e.g. QKV) are sharded block-by-block along dim 0 so
        # that every shard receives a piece of each packed sub-tensor.
        weight_name = f"{prefix}.weight"
        return weights.get_packed_sharded(
            weight_name, dim=0, block_sizes=block_sizes
        )

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # Shard every prefix along its output dimension first, then join the
        # per-prefix shards along the requested dimension.
        weight_names = [f"{name}.weight" for name in prefixes]
        shards = [weights.get_sharded(name, dim=0) for name in weight_names]
        return torch.cat(shards, dim=dim)

    def get_weights_row(self, weights: "Weights", prefix: str):
        # Row-parallel layers split the input dimension (dim 1).
        weight_name = f"{prefix}.weight"
        return weights.get_sharded(weight_name, dim=1)
87
88
89


class Weights:
    """
    Tensor-parallel-aware access to weights stored in safetensors files.

    Tensor names are resolved through an optional `prefix` and an alias
    table. Safetensors handles are opened lazily and cached. Loading is either
    done directly (`get_tensor`, `get_sharded`, `get_packed_sharded`, ...) or
    delegated to the current `WeightsLoader` (`get_weights_col`,
    `get_weights_row`, ...), which can be swapped temporarily with
    `use_loader`.
    """

    # Integer dtypes are never cast to `self.dtype`: quantizers store packed
    # data in them (e.g. GPTQ u4 values disguised as int32; exl2 uses int16).
    _NON_CASTABLE_DTYPES = (torch.int16, torch.int32, torch.int64)

    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        weights_loader: "WeightsLoader",
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        """
        Args:
            filenames: safetensors files that together hold all weights.
            device: device that loaded tensors are moved to.
            dtype: dtype that loaded tensors are cast to, except tensors whose
                dtype is in `_NON_CASTABLE_DTYPES` (`get_tensor_shard` always
                casts).
            process_group: object exposing `rank()` and `size()`, used to
                pick this process's shard.
            weights_loader: loader used by the `get_weights_*` /
                `get_multi_weights_col` methods.
            aliases: optional map from a tensor name to alternative names that
                are tried when the name itself is not found.
            prefix: optional prefix; `f"{prefix}.{name}"` is tried after
                `name` when resolving tensor names.

        Raises:
            RuntimeError: if the same tensor name appears in more than one file.
        """
        # Map every tensor name to the single file that contains it.
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename

        self.routing = routing
        self.aliases = aliases if aliases is not None else {}
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self.weights_loader = weights_loader
        # filename -> open safetensors handle, filled lazily by `_get_handle`.
        self._handles = {}

    def _get_handle(self, filename):
        """Return the cached safetensors handle for `filename`, opening it on first use."""
        if filename not in self._handles:
            self._handles[filename] = safe_open(filename, framework="pytorch")
        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        """
        Resolve `tensor_name` to the file that stores it.

        Candidates are tried in order: `tensor_name` itself, then (when a
        prefix is set) `f"{prefix}.{tensor_name}"`. Each candidate is looked
        up directly first and then through its aliases.

        Returns:
            `(filename, resolved_name)`, where `resolved_name` is the key that
            is actually present in the safetensors file.

        Raises:
            RuntimeError: if neither a candidate nor any alias is found.
        """
        names = [tensor_name]
        if self.prefix is not None:
            names.append(f"{self.prefix}.{tensor_name}")

        for name in names:
            filename = self.routing.get(name)
            if filename is not None:
                return str(filename), name

            for alias in self.aliases.get(name, []):
                filename = self.routing.get(alias)
                if filename is not None:
                    return str(filename), alias

        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
        """Return the safetensors slice object for `tensor_name` after name resolution."""
        filename, resolved_name = self.get_filename(tensor_name)
        return self._get_handle(filename).get_slice(resolved_name)

    def _has_tensor(self, tensor_name: str) -> bool:
        """Return whether `tensor_name` (directly, prefixed, or via an alias) exists."""
        try:
            self.get_filename(tensor_name)
        except RuntimeError:
            # `get_filename` signals a missing weight with RuntimeError; any
            # other exception is a genuine bug and should propagate.
            return False
        return True

    def get_shape(self, tensor_name: str):
        """Return the stored shape of `tensor_name` without loading its data."""
        return self._get_slice(tensor_name).get_shape()

    def _cast_to_dtype(self, tensor: torch.Tensor) -> torch.Tensor:
        """Cast `tensor` to `self.dtype` unless its dtype is a quantizer integer dtype."""
        if tensor.dtype in self._NON_CASTABLE_DTYPES:
            return tensor
        return tensor.to(dtype=self.dtype)

    def get_tensor(self, tensor_name: str, to_device=True):
        """
        Load the full (unsharded) tensor `tensor_name`.

        The tensor is cast to `self.dtype` (integer quantizer dtypes are kept)
        and moved to `self.device` unless `to_device` is False.
        """
        filename, resolved_name = self.get_filename(tensor_name)
        tensor = self._get_handle(filename).get_tensor(resolved_name)
        tensor = self._cast_to_dtype(tensor)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
        """
        Load this rank's shard of `tensor_name` along `dim` (0 or 1).

        Each shard is `ceil(size / world_size)` wide, so the size along `dim`
        does not need to be divisible by the world size; the trailing shard(s)
        can be smaller.

        Raises:
            NotImplementedError: if `dim` is neither 0 nor 1.
        """
        slice_ = self._get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")

        # Same integer dtypes are preserved here as in `get_tensor` and
        # `get_packed_sharded`.
        tensor = self._cast_to_dtype(tensor)
        return tensor.to(device=self.device)

    def get_sharded(self, tensor_name: str, dim: int):
        """
        Load this rank's equally sized shard of `tensor_name` along `dim`.

        Raises:
            AssertionError: if the size along `dim` is not divisible by the
                world size.
        """
        slice_ = self._get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def get_packed_sharded(
        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
    ) -> torch.Tensor:
        """
        Get a shard from a tensor that packs multiple tensors.

        When a tensor packs multiple tensors (such as QKV or an up
        projection + gate projection), sharding with `get_sharded` is not
        safe since it would not split the packed tensors across shards.

        This method shards a tensor, such that the packed tensors are
        split across shards.

        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional to the given sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` into `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.

        Raises:
            AssertionError: if a block cannot be split evenly across ranks.
            NotImplementedError: if `dim` is neither 0 nor 1.
        """
        slice_ = self._get_slice(tensor_name)
        total_size = slice_.get_shape()[dim]
        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

        world_size = self.process_group.size()
        rank = self.process_group.rank()

        # Take this rank's piece of every packed block, then re-pack them.
        tensors = []
        block_offset = 0
        for block_size in block_sizes:
            assert (
                block_size % world_size == 0
            ), f"Prepacked tensor cannot be sharded across {world_size} shards"
            shard_block_size = block_size // world_size
            start = block_offset + rank * shard_block_size
            stop = block_offset + (rank + 1) * shard_block_size
            if dim == 0:
                tensor = slice_[start:stop]
            elif dim == 1:
                tensor = slice_[:, start:stop]
            else:
                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
            tensors.append(tensor)
            block_offset += block_size

        tensor = torch.cat(tensors, dim=dim)
        tensor = tensor.to(device=self.device)
        # Avoid casting quantizer dtypes.
        return self._cast_to_dtype(tensor)

    def get_weights_col_packed_qkv(
        self,
        prefix: str,
        num_heads: int,
        num_key_value_heads: int,
    ):
        """Column-shard a packed QKV weight with Q:K:V proportions `num_heads:kv:kv`."""
        return self.get_weights_col_packed(
            prefix, [num_heads, num_key_value_heads, num_key_value_heads]
        )

    def get_weights_col_packed_gate_up(self, prefix: str):
        """Column-shard a packed gate/up weight made of two equally sized blocks."""
        return self.get_weights_col_packed(prefix, 2)

    def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
        """
        Column-shard a packed weight through the current weights loader.

        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional to the given sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` into `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.
        """
        return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)

    def get_weights_col(self, prefix: str):
        """Column-shard the weight at `prefix` through the current weights loader."""
        return self.weights_loader.get_weights_col(self, prefix)

    def get_multi_weights_col(self, prefixes: List[str], dim: int):
        """Column-shard the weights at `prefixes` and concatenate them along `dim`."""
        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)

    def get_tensor_shard(self, var, dim):
        """
        Shard an already-loaded tensor `var` for this rank along `dim` (0 or 1).

        The shard width is `size // world_size`, so any remainder is dropped.
        Unlike the file-backed loaders, the result is always cast to
        `self.dtype` before being moved to `self.device`.

        Raises:
            NotImplementedError: if `dim` is neither 0 nor 1.
        """
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_weights_row(self, prefix: str):
        """Row-shard the weight at `prefix` through the current weights loader."""
        return self.weights_loader.get_weights_row(self, prefix)

    @contextmanager
    def use_loader(self, weights_loader: "WeightsLoader"):
        """
        This method is a context manager that can be used to use `Weights` with
        a different loader for the duration of the context.
        """
        old_loader = self.weights_loader
        self.weights_loader = weights_loader
        try:
            yield
        finally:
            # Restore the previous loader even if the body raised.
            self.weights_loader = old_loader

323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349

def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    """
    Convert block count or proportions to block sizes.

    This function accepts

    - The number of blocks (int), in which case the block size is
      total_size//blocks; or
    - A list of block sizes (List[int]).

    In the latter case, if sum(blocks) < total_size, the ratios between
    the block sizes will be preserved. For instance, if blocks is
    [2, 1, 1] and total_size is 1024, the returned block sizes are
    [512, 256, 256].
    """
    if isinstance(blocks, list):
        total_blocks = sum(blocks)
        assert (
            total_size % total_blocks == 0
        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]
    else:
        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
        single_size = total_size // blocks
        return [single_size] * blocks