weights.py
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger


class Weights:
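    """Lazy loader for model weights stored in safetensors files.

    Tensors are read on demand from whichever file contains them, cast to the
    target dtype (packed integer tensors used by GPTQ are left untouched),
    moved to the target device, and optionally sharded across the process
    group.

    Illustrative usage (the tensor names below are placeholders):

        weights = Weights(filenames, device, dtype, process_group)
        full = weights.get_tensor("embed_tokens.weight")
        shard = weights.get_sharded("lm_head.weight", dim=0)
    """
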
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
    ):
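        # Map every tensor name to the file that stores it, and error out on
        # names that appear in more than one file.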
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self._handles = {}

    def _get_handle(self, filename):
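        # Cache one open safetensors handle per file so repeated lookups do not
        # keep reopening the same file.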
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
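        """Resolve `tensor_name` (or one of its aliases) to the file that
        stores it; returns `(filename, resolved_name)`."""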
        filename = self.routing.get(tensor_name, None)
        if filename is None:
            aliases = self.aliases.get(tensor_name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
            raise RuntimeError(f"weight {tensor_name} does not exist")
        return str(filename), tensor_name

    def _get_slice(self, tensor_name: str):
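        # Return a lazy slice over the stored tensor so only the requested
        # rows/columns are read from disk.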
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str):
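        # Load the full (unsharded) tensor, casting to the target dtype unless
        # it is an int32/int64 buffer (kept as-is for GPTQ packed weights).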
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32
        if tensor.dtype not in [torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
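        # Load only this rank's contiguous block of the tensor along `dim`.
        # If the dimension does not divide evenly by the world size, the
        # trailing remainder is dropped (hence "partial"); use `get_sharded`
        # when an exact split is required.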
        filename, tensor_name = self.get_filename(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32
        if tensor.dtype != torch.int32:
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
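        # Same as `get_partial_sharded`, but first check that the dimension
        # divides evenly across the shards.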
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
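        # Load several column-sharded weights and concatenate them into one
        # tensor; for GPTQ, the packed qweight/qzeros/scales are concatenated
        # instead and returned as part of a tuple.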
        if quantize == "gptq":
            try:
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            qzeros = torch.cat(
                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )
            scales = torch.cat(
                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
            )
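            # Every prefix must share the same g_idx for the concatenated
            # weight to remain valid.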
            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]

            bits, groupsize = self._get_gptq_qparams()
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)
        return weight

    def get_multi_weights_row(self, prefix: str, quantize: str):
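        # Load a row-sharded weight. For GPTQ, additionally decide whether the
        # exllama kernel layout can be used (4-bit only, and no act-order when
        # the weight is split across ranks).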
        if quantize == "gptq":
            use_exllama = True
            bits, groupsize = self._get_gptq_qparams()

            if bits != 4:
                use_exllama = False

            if self.process_group.size() > 1:
                g_idx = self.get_tensor(f"{prefix}.g_idx")
                if g_idx is not None:
                    if not torch.equal(
                        g_idx.cpu(),
                        torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32),
                    ) and not (g_idx == 0).all():
                        # Exllama implementation does not support row tensor parallelism with
                        # act-order, as it would require reordering the input activations that
                        # are split onto several GPUs.
                        use_exllama = False

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
            

            from text_generation_server.utils.layers import HAS_EXLLAMA
            if use_exllama:
                if not HAS_EXLLAMA:
                    logger.warning("Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True")
                    use_exllama = False
                else:
                    logger.info("Using exllama kernels")


            if use_exllama:
                if groupsize >= 0:
                    # Exllama reorders the weights in advance and the activations on the fly, thus
                    # the scales and zero-points do not need to be reordered.
                    qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
                else:
                    raise RuntimeError("Using exllama GPTQ kernel with groupsize<1 is not supported") 
                    # qzeros = self.get_tensor(f"{prefix}.qzeros")
                    # scales = self.get_tensor(f"{prefix}.scales")

                # For tp > 1, at this point we know we do not use act-order
                if self.process_group.size() == 1:
                    g_idx = self.get_tensor(f"{prefix}.g_idx")
                else:
                    g_idx = None
            else:
                # The triton kernel reorders the scales/zero points instead of the weight/activation.
                # Thus, each rank needs the full qzeros/scales.
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")
                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight

    def _get_gptq_qparams(self) -> Tuple[int, int]:
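        # Read GPTQ bits/groupsize from the checkpoint, falling back to the
        # GPTQ_BITS / GPTQ_GROUPSIZE environment variables if those tensors
        # are absent.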
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
        except (SafetensorError, RuntimeError) as e:
            try:
                import os

                bits = int(os.getenv("GPTQ_BITS"))
                groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
            except Exception:
                raise e

        return bits, groupsize