# gpt_neox.py (9.3 KB)
# NOTE: captured from a repository web view; the viewer reported that
# "docs/vscode:/vscode.git/clone" did not exist on "fed4c6946acd476ab94cad85a1210900a3ae6076".
from typing import List, Optional

import torch
import torch.distributed
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
)
from transformers.models.gpt_neox.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation_server.models import CausalLM
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)

# Optional int8 quantization support; only required when quantize=True.
HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception:
    # Deliberately broader than ImportError: presumably the package can be
    # installed yet still fail to load (e.g. on a machine without a GPU).
    HAS_BITS_AND_BYTES = False


class GPTNeoxSharded(CausalLM):
    """GPT-NeoX causal LM whose weights are sharded across a torch.distributed group.

    Each rank loads only its slice of the safetensors checkpoint (see
    ``load_weights``). When the model's embeddings are sharded
    (``model.gpt_neox.tp_embeddings``), ``forward`` gathers the per-rank
    logits from all ranks.
    """

    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        """Build this rank's shard of the model.

        Args:
            model_id: Model id/path passed to the tokenizer and config loaders
                and to ``weight_files``.
            revision: Optional revision forwarded to the same loaders.
            quantize: If True, tensor-parallel linear weights are converted to
                bitsandbytes int8 (see ``load_weights``).
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        self.master = rank == 0
        if torch.cuda.is_available():
            # One CUDA device per rank, half precision.
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        # Padding (on the left, per padding_side above) reuses the EOS token.
        tokenizer.pad_token = tokenizer.eos_token

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, tp_parallel=True
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        # Build the module tree without allocating real weights; every
        # parameter is replaced by load_weights, which fails if any is left
        # on the meta device.
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=rank,
            world_size=world_size,
        )
        self.model = model.eval()
        torch.distributed.barrier(group=self.process_group)
        # super(CausalLM, self) skips CausalLM.__init__ and runs the next
        # initializer in the MRO: the model has already been built above.
        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @staticmethod
    def _shard(slice_, dim: int, rank: int, world_size: int):
        """Return ``rank``'s contiguous block of a safetensors slice along ``dim`` (0 or 1).

        The block size is ``size // world_size``. If ``size`` is not divisible
        by ``world_size``, the trailing remainder is assigned to no rank.
        """
        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
        stop = start + block_size
        if dim == 0:
            return slice_[start:stop]
        return slice_[:, start:stop]

    @staticmethod
    def _quantize_linear(module, tensor: torch.Tensor, device: torch.device):
        """Convert ``tensor`` to a bitsandbytes int8 weight on ``device``.

        Also replaces ``module.linear`` with a closure that runs ``bnb.matmul``
        with this weight's ``MatmulLtState``.

        Returns:
            The ``Int8Params`` tensor to register as the module's weight.
        """
        tensor = Int8Params(
            tensor,
            has_fp16_weights=False,
            requires_grad=False,
        ).to(device)
        state = bnb.MatmulLtState()
        state.threshold = 6.0
        state.has_fp16_weights = False
        state.memory_efficient_backward = False
        state.use_pool = True
        # Move the quantized data into the matmul state.
        state.CB = tensor.CB
        state.SCB = tensor.SCB
        tensor.CB = None
        tensor.SCB = None

        # `state` is local to this call, so each module's closure binds its
        # own state object.
        def linear(input, weight, bias):
            out = bnb.matmul(
                input,
                weight,
                state=state,
                threshold=state.threshold,
                bias=bias,
            )

            if state.CB is not None:
                # we converted 8-bit row major to turing/ampere format
                # in the first inference pass
                # we no longer need the row-major weight
                del state.CB
                weight.data = state.CxB

            return out

        module.linear = linear
        return tensor

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        """Load ``rank``'s shard of every tensor in ``filenames`` into ``model``.

        Sharding rules:

        - Column-parallel linears and parallel embeddings are split along
          dim 0. So is ``embed_out.weight`` when
          ``model.gpt_neox.tp_embeddings`` is set.
        - Row-parallel linear weights are split along dim 1.
        - Row-parallel biases are kept whole on rank 0 and zeroed elsewhere.
        - Every other tensor is loaded unsharded.

        Tensors are cast to ``dtype`` and placed on ``device``. When
        ``quantize`` is set, tensor-parallel linear weights become
        bitsandbytes int8.

        Raises:
            ImportError: ``quantize`` is set but bitsandbytes is unavailable.
            ValueError: A loaded tensor's shape differs from the model parameter's.
            RuntimeError: Some parameters are still on the meta device afterwards.
        """
        # Fail fast instead of re-checking on every tensor.
        if quantize and not HAS_BITS_AND_BYTES:
            raise ImportError(
                "bitsandbytes is not available on your machine either because it is not installed "
                "or you don't have a GPU.\n"
                "You can install it with `pip install bitsandbytes`."
            )

        parameters = dict(model.named_parameters())
        # When quantizing, read on CPU; tensors are moved to `device` below.
        load_device = "cpu" if quantize else str(device)
        for file in filenames:
            with safe_open(file, framework="pt", device=load_device) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    current_parameter_tensor = parameters.get(name)

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        tensor = GPTNeoxSharded._shard(slice_, 0, rank, world_size)
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            tensor = GPTNeoxSharded._shard(slice_, 1, rank, world_size)
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once
                            # (presumably the per-rank outputs are summed).
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding) or (
                        name == "embed_out.weight" and model.gpt_neox.tp_embeddings
                    ):
                        tensor = GPTNeoxSharded._shard(slice_, 0, rank, world_size)
                    else:
                        try:
                            tensor = slice_[:]
                        except Exception:
                            # Some tensors cannot be sliced (presumably 0-dim
                            # ones); read them whole instead.
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)

                    if quantize:
                        if param_name == "weight" and isinstance(
                            module,
                            (TensorParallelRowLinear, TensorParallelColumnLinear),
                        ):
                            tensor = GPTNeoxSharded._quantize_linear(
                                module, tensor, device
                            )
                        else:
                            tensor = tensor.to(device)

                    # Names that are model parameters are registered as
                    # parameters; anything else is stored as a buffer.
                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

        uninitialized_parameters = [
            n
            for n, p in model.named_parameters()
            if p.data.device == torch.device("meta")
        ]
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model: {uninitialized_parameters}"
            )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        """Run the model and return ``(logits, past_key_values)``.

        If the embeddings are sharded, each rank's logits are all-gathered and
        concatenated along dim 2. Otherwise the parent class's ``forward`` is
        used unchanged.
        """
        # While the model itself is sharded, the embeddings might not be, as
        # they might not be divisible by the number of shards.
        if not self.model.gpt_neox.tp_embeddings:
            return super(GPTNeoxSharded, self).forward(
                input_ids, attention_mask, position_ids, past_key_values
            )

        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them from every rank.
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)

        return logits, outputs.past_key_values