from pathlib import Path
from typing import List, Dict, Optional, Tuple
from safetensors import safe_open
import torch


class Weights:
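    """Lazily load model tensors from a set of safetensors files.

    Tensors can be fetched whole (`get_tensor`) or split across the ranks of
    `process_group` (`get_sharded`); non-integer tensors are cast to `dtype`
    and every tensor is moved to `device`.
    """
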
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
    ):
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self._handles = {}

    def _get_handle(self, filename):
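        # Keep one open safetensors handle per file and reuse it for later reads.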
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        """Return the file that stores `tensor_name` and the name it is stored
        under (resolving aliases when the exact name is missing)."""
        filename = self.routing.get(tensor_name, None)
        if filename is None:
            aliases = self.aliases.get(tensor_name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
            raise RuntimeError(f"weight {tensor_name} does not exist")
        return str(filename), tensor_name

    def _get_slice(self, tensor_name: str):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str):
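        """Load a whole tensor, cast it to the target dtype (int32/int64 are
        left untouched) and move it to the device."""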
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for GPTQ, which packs u4 values into int32 tensors:
        # integer tensors must keep their dtype instead of being cast.
        if tensor.dtype not in [torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
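        """Load this rank's shard of a tensor, split evenly along `dim`
        across the process group (only dim 0 and 1 are supported)."""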
        filename, tensor_name = self.get_filename(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
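        # e.g. with size=4096 and world_size=2, rank 0 reads the range [0, 2048)
        # along `dim` and rank 1 reads [2048, 4096).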

        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for GPTQ: u4 values packed into int32 must not be cast.
        if tensor.dtype != torch.int32:
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
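        """Concatenate the column-sharded weights of `prefixes` along `dim`.

        For `quantize == "gptq"` this returns the raw quantization tensors
        (qweight, qzeros, scales, g_idx, bits, groupsize) instead of a dense weight.
        """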
        if quantize == "gptq":
            try:
                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
            except RuntimeError:
                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")

            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]

            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)
        return weight

    def get_multi_weights_row(self, prefix: str, quantize: str):
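        """Load this rank's shard of the row-parallel weight at `prefix`;
        for `quantize == "gptq"` the quantization tensors are returned as a tuple."""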
        if quantize == "gptq":
            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
            qzeros = self.get_tensor(f"{prefix}.qzeros")
            scales = self.get_tensor(f"{prefix}.scales")
            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()

            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight
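

# Illustrative usage sketch (not part of the module). It assumes a
# `process_group` exposing `.size()` and `.rank()` (e.g. a torch.distributed
# ProcessGroup) and that the file and tensor names below, which are only
# examples, actually exist in the checkpoint:
#
#   weights = Weights(
#       filenames=[Path("model-00001-of-00002.safetensors"),
#                  Path("model-00002-of-00002.safetensors")],
#       device=torch.device("cuda:0"),
#       dtype=torch.float16,
#       process_group=process_group,
#   )
#   full = weights.get_tensor("model.embed_tokens.weight")
#   shard = weights.get_sharded("model.layers.0.self_attn.q_proj.weight", dim=0)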