import os
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json


class Weights:
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
    ):
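        # Build a routing table mapping each tensor name to the safetensors file
        # that contains it; duplicate keys across files are an error, and
        # `aliases` provides alternative names resolved in `get_filename`.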
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self._handles = {}

    def _get_handle(self, filename):
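        # Open each safetensors file lazily and cache the handle for reuse.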
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        filename = self.routing.get(tensor_name, None)
        if filename is None:
            aliases = self.aliases.get(tensor_name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
            raise RuntimeError(f"weight {tensor_name} does not exist")
        return str(filename), tensor_name

    def _get_slice(self, tensor_name: str):
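        # Return a lazy slice over the tensor so callers can read only the part they need.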
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device: bool = True):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq: do not convert the packed u4 weights,
        # which are disguised as int32.
        if tensor.dtype not in [torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
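        # Load only this rank's block of `tensor_name` along dimension `dim`;
        # the dimension does not have to split evenly (any remainder is dropped).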
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq: do not convert the packed u4 weights,
        # which are disguised as int32.
        if tensor.dtype != torch.int32:
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
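        # Same as get_partial_sharded, but requires `dim` to split evenly across ranks.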
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)


    def _get_qweight(self, name: str):
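        # Shard a GPTQ-packed tensor whose columns are the concatenation of Q, K and V:
        # take this rank's block from each of the three segments and re-concatenate them.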
        slice_ = self._get_slice(name)
        total_size = slice_.get_shape()[1]
        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
        single_size = total_size // 3
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        assert (
            single_size % world_size == 0
        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
        block_size = single_size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        q = slice_[:, start:stop]
        k = slice_[:, start + single_size : stop + single_size]
        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
        weight = torch.cat([q, k, v], dim=1)
        weight = weight.to(device=self.device)
        return weight

    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
        """
        Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
        already alternating Q,K,V within the main tensor
        """
        if quantize == "gptq":
            try:
                qweight = self._get_qweight(f"{prefix}.qweight")
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            qzeros = self._get_qweight(f"{prefix}.qzeros")
            scales = self._get_qweight(f"{prefix}.scales")
            scales = scales.to(dtype=self.dtype)
            g_idx = self.get_tensor(f"{prefix}.g_idx")

            bits, groupsize = self._get_gptq_params()
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
        else:
            slice_ = self._get_slice(f"{prefix}.weight")
            total_size = slice_.get_shape()[0]
            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
            single_size = total_size // 3
            world_size = self.process_group.size()
            rank = self.process_group.rank()

            assert (
                single_size % world_size == 0
            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
            block_size = single_size // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            q = slice_[start:stop]
            k = slice_[start + single_size : stop + single_size]
            v = slice_[start + 2 * single_size : stop + 2 * single_size]
            weight = torch.cat([q, k, v], dim=0)
            weight = weight.to(device=self.device)
            weight = weight.to(dtype=self.dtype)
        return weight

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
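        # Concatenate this rank's column shards of several prefixes into a single weight.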
        if quantize == "gptq":
            try:
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            qzeros = torch.cat(
                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )
            scales = torch.cat(
                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
            )
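            # All prefixes must share the same group index for the fused weight.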
            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]

            bits, groupsize = self._get_gptq_params()
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)
        return weight
    
    def get_tensor_shard(self, var, dim):
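        # Shard an already-loaded tensor along `dim` for this rank.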
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_multi_weights_row(self, prefix: str, quantize: str):
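        # Load a row-parallel weight (sharded along the input dimension). For GPTQ,
        # pick between the exllama and triton kernel layouts.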
        if quantize == "gptq":
            use_exllama = True
            bits, groupsize = self._get_gptq_params()

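            # The exllama kernels only support 4-bit quantization.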
            if bits != 4:
                use_exllama = False

            if self.process_group.size() > 1:
                g_idx = self.get_tensor(f"{prefix}.g_idx")
                if g_idx is not None:
                    if (
                        not torch.equal(
                            g_idx.cpu(),
                            torch.tensor(
                                [i // groupsize for i in range(g_idx.shape[0])],
                                dtype=torch.int32,
                            ),
                        )
                        and not (g_idx == 0).all()
                    ):
                        # The exllama implementation does not support row tensor parallelism with
                        # act-order, as it would require reordering the input activations that are
                        # split across several GPUs.
                        use_exllama = False

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA

            if use_exllama:
                if not HAS_EXLLAMA:
                    if CAN_EXLLAMA:
                        logger.warning(
                            "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
                        )
                    use_exllama = False
                else:
                    logger.info("Using exllama kernels")

            if use_exllama:
                if groupsize >= 0:
                    # Exllama reorders the weights in advance and the activations on the fly, thus
                    # the scales and zero-points do not need to be reordered.
                    qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
                else:
                    qzeros = self.get_tensor(f"{prefix}.qzeros")
                    scales = self.get_tensor(f"{prefix}.scales")

                # For tp > 1, at this point we know we do not use act-order
                if self.process_group.size() == 1:
                    g_idx = self.get_tensor(f"{prefix}.g_idx")
                else:
                    g_idx = None
            else:
                # The triton kernel reorders the scales/zero points instead of the weight/activation.
                # Thus, each rank needs the full qzeros/scales.
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")
                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight

    def _get_gptq_params(self) -> Tuple[int, int]:
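        # Prefer the bits/groupsize tensors stored in the checkpoint itself; fall back
        # to the values parsed from the model config by `_set_gptq_params`.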
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
        except (SafetensorError, RuntimeError) as e:
            try:
                bits = self.gptq_bits
                groupsize = self.gptq_groupsize
            except Exception:
                raise e

        return bits, groupsize

    def _set_gptq_params(self, model_id):
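        # Read quantization bits and group_size from config.json (local or from the Hub),
        # falling back to quantize_config.json when that fails.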
        filename = "config.json"
        try:
            if os.path.exists(os.path.join(model_id, filename)):
                filename = os.path.join(model_id, filename)
            else:
                filename = hf_hub_download(model_id, filename=filename)
            with open(filename, "r") as f:
                data = json.load(f)
            self.gptq_bits = data["quantization_config"]["bits"]
            self.gptq_groupsize = data["quantization_config"]["group_size"]
        except Exception:
            filename = "quantize_config.json"
            try:
                if os.path.exists(os.path.join(model_id, filename)):
                    filename = os.path.join(model_id, filename)
                else:
                    filename = hf_hub_download(model_id, filename=filename)
                with open(filename, "r") as f:
                    data = json.load(f)
                self.gptq_bits = data["bits"]
                self.gptq_groupsize = data["group_size"]
            except Exception:
                pass