import os
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json


class Weights:
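    """Lazily loads model weights from a set of safetensors files.

    On construction, every tensor name is routed to the file that contains it;
    tensors are then fetched on demand and, where requested, sharded across the
    ranks of the given process group.
    """
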
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
    ):
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self._handles = {}

    def _get_handle(self, filename):
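        # Keep one open safetensors handle per file and reuse it across
        # tensor lookups instead of re-opening the file each time.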
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
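        """Return (filename, resolved tensor name) for `tensor_name`, following aliases."""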
        filename = self.routing.get(tensor_name, None)
        if filename is None:
            aliases = self.aliases.get(tensor_name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
            raise RuntimeError(f"weight {tensor_name} does not exist")
        return str(filename), tensor_name

    def _get_slice(self, tensor_name: str):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True):
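        """Load a full (unsharded) tensor, casting to `self.dtype` except for
        integer tensors (GPTQ-packed weights are kept as int32/int64)."""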
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for GPTQ: do not convert the packed u4 weights,
        # which are disguised as int32
        if tensor.dtype not in [torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
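        """Return this rank's contiguous block of the tensor along `dim`,
        without checking that the size divides evenly across ranks."""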
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for GPTQ: do not convert the packed u4 weights,
        # which are disguised as int32
        if tensor.dtype != torch.int32:
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
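        """Shard a tensor along `dim`, asserting first that the dimension is
        evenly divisible by the world size."""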
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def _get_qweight(self, name: str):
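        """Shard a column-packed [Q | K | V] quantized tensor: slice this rank's
        block out of each of the three sections and re-concatenate them."""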
        slice_ = self._get_slice(name)
        total_size = slice_.get_shape()[1]
        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
        single_size = total_size // 3
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        assert (
            single_size % world_size == 0
        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
        block_size = single_size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        q = slice_[:, start:stop]
        k = slice_[:, start + single_size : stop + single_size]
        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
        weight = torch.cat([q, k, v], dim=1)
        weight = weight.to(device=self.device)
        return weight

    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
        """
        Highly specific loader for when the underlying tensor is a simple
        concatenation of Q, K, V, rather than Q, K, V already being interleaved
        within the main tensor.
        """
        if quantize in ["gptq", "awq"]:
            try:
                qweight = self._get_qweight(f"{prefix}.qweight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                )

            qzeros = self._get_qweight(f"{prefix}.qzeros")
            scales = self._get_qweight(f"{prefix}.scales")
            scales = scales.to(dtype=self.dtype)
            if quantize == "gptq":
                g_idx = self.get_tensor(f"{prefix}.g_idx")
            else:
                g_idx = None

            bits, groupsize = self._get_gptq_params()
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
        else:
            slice_ = self._get_slice(f"{prefix}.weight")
            total_size = slice_.get_shape()[0]
            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
            single_size = total_size // 3
            world_size = self.process_group.size()
            rank = self.process_group.rank()

            assert (
                single_size % world_size == 0
            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
            block_size = single_size // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            q = slice_[start:stop]
            k = slice_[start + single_size : stop + single_size]
            v = slice_[start + 2 * single_size : stop + 2 * single_size]
            weight = torch.cat([q, k, v], dim=0)
            weight = weight.to(device=self.device)
            weight = weight.to(dtype=self.dtype)
        return weight

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
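        """Load the weights for several prefixes, shard each along its output
        dimension and concatenate them (returned as a GPTQ/AWQ tuple when quantized)."""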
        if quantize in ["gptq", "awq"]:
            try:
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                )

            qzeros = torch.cat(
                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )
            scales = torch.cat(
                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
            )

            if quantize == "gptq":
                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                for w2 in w[1:]:
                    torch.testing.assert_close(w2, w[0])
                g_idx = w[0]
            else:
                g_idx = None

            bits, groupsize = self._get_gptq_params()
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)
        return weight

    def get_tensor_shard(self, var, dim):
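        """Shard an already-materialized tensor `var` along `dim` for this rank."""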
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_multi_weights_row(self, prefix: str, quantize: str):
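        """Load a row-sharded weight for `prefix`: plain weights are sharded along
        dim 1; GPTQ/AWQ weights are returned as a tuple, with the exllama kernels
        used only when installed and supported (4-bit, and no act-order reordering
        when row tensor-parallel)."""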
        if quantize == "gptq":
            use_exllama = True
            bits, groupsize = self._get_gptq_params()

            if bits != 4:
                use_exllama = False

            if self.process_group.size() > 1:
                g_idx = self.get_tensor(f"{prefix}.g_idx")
                if g_idx is not None:
                    if (
                        not torch.equal(
                            g_idx.cpu(),
                            torch.tensor(
                                [i // groupsize for i in range(g_idx.shape[0])],
                                dtype=torch.int32,
                            ),
                        )
                        and not (g_idx == 0).all()
                    ):
                        # Exllama implementation does not support row tensor parallelism with act-order, as
                        # it would require reordering the input activations that are split across several GPUs
                        use_exllama = False

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA

            if use_exllama:
                if not HAS_EXLLAMA:
                    if CAN_EXLLAMA:
                        logger.warning(
                            "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
                        )
                    use_exllama = False
                else:
                    logger.info("Using exllama kernels")

            if use_exllama:
                if groupsize >= 0:
                    # Exllama reorders the weights in advance and the activations on the fly, thus
                    # the scales and zero-points do not need to be reordered.
                    qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
                else:
                    qzeros = self.get_tensor(f"{prefix}.qzeros")
                    scales = self.get_tensor(f"{prefix}.scales")

                # For tp > 1, at this point we know we do not use act-order
                if self.process_group.size() == 1:
                    g_idx = self.get_tensor(f"{prefix}.g_idx")
                else:
                    g_idx = None
            else:
                # The triton kernel reorders the scales/zero points instead of the weight/activation.
                # Thus, each rank needs the full qzeros/scales.
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")
                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        elif quantize == "awq":
            bits, groupsize = self._get_gptq_params()

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `awq` weight, make sure the model is already quantized"
                )

            qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
            scales = self.get_sharded(f"{prefix}.scales", dim=0)
            g_idx = None
            use_exllama = False

            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight

    def _get_gptq_params(self) -> Tuple[int, int]:
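        """Return (bits, groupsize), read from the checkpoint's gptq_bits /
        gptq_groupsize tensors if present, otherwise from attributes set by
        `_set_gptq_params`."""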
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
        except (SafetensorError, RuntimeError) as e:
            try:
                bits = self.gptq_bits
                groupsize = self.gptq_groupsize
            except Exception:
                raise e

        return bits, groupsize

    def _set_gptq_params(self, model_id):
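        """Best-effort lookup of quantization bits / group size from config.json,
        quantize_config.json or quant_config.json, either local or downloaded
        from the Hub."""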
        filename = "config.json"
        try:
            if os.path.exists(os.path.join(model_id, filename)):
                filename = os.path.join(model_id, filename)
            else:
                filename = hf_hub_download(model_id, filename=filename)
            with open(filename, "r") as f:
                data = json.load(f)
            self.gptq_bits = data["quantization_config"]["bits"]
            self.gptq_groupsize = data["quantization_config"]["group_size"]
        except Exception:
            filename = "quantize_config.json"
            try:
                if os.path.exists(os.path.join(model_id, filename)):
                    filename = os.path.join(model_id, filename)
                else:
                    filename = hf_hub_download(model_id, filename=filename)
                with open(filename, "r") as f:
                    data = json.load(f)
                self.gptq_bits = data["bits"]
                self.gptq_groupsize = data["group_size"]
            except Exception:
                filename = "quant_config.json"
                try:
                    if os.path.exists(os.path.join(model_id, filename)):
                        filename = os.path.join(model_id, filename)
                    else:
                        filename = hf_hub_download(model_id, filename=filename)
                    with open(filename, "r") as f:
                        data = json.load(f)
                    self.gptq_bits = data["w_bit"]
                    self.gptq_groupsize = data["q_group_size"]
                except Exception:
                    pass
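

# A minimal usage sketch (not part of the original module). The checkpoint file
# names, device and tensor names below are hypothetical, and a real run needs an
# initialized torch.distributed process group:
#
#     import torch.distributed as dist
#
#     weights = Weights(
#         filenames=[
#             Path("model-00001-of-00002.safetensors"),
#             Path("model-00002-of-00002.safetensors"),
#         ],
#         device=torch.device("cuda:0"),
#         dtype=torch.float16,
#         process_group=dist.group.WORLD,
#     )
#     embeddings = weights.get_tensor("model.embed_tokens.weight")
#     o_proj = weights.get_sharded("model.layers.0.self_attn.o_proj.weight", dim=1)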