"tests/vscode:/vscode.git/clone" did not exist on "7202115ebbc64edb4000bdd7eed8f276a556304e"
flash_santacoder.py 16.3 KB
Newer Older
1
2
3
4
5
import torch
import torch.distributed

from accelerate import init_empty_weights
from opentelemetry import trace
6
from safetensors import safe_open
7
from pathlib import Path
8
from transformers import AutoTokenizer, GPT2Config
9
10
11
12
from typing import Optional, List

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
13
    FlashSantacoderForCausalLM,
14
15
16
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
17
18
)
from text_generation_server.utils import (
19
    initialize_torch_distributed,
20
21
22
23
24
25
26
27
28
29
    weight_files,
    download_weights,
    weight_hub_files,
    LocalEntryNotFoundError,
)

# Module-level OpenTelemetry tracer named after this module.
tracer = trace.get_tracer(__name__)


class FlashSantacoder(FlashCausalLM):
    """Santacoder served on a single GPU through the flash causal-LM code path.

    Weights are read straight from ``.safetensors`` files instead of going
    through ``from_pretrained``, because the custom modeling code fuses the
    checkpoint's separate ``q_attn`` / ``kv_attn`` projections into a single
    ``c_attn`` layer.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights for ``model_id`` onto ``cuda``.

        Args:
            model_id: Model id or path passed to ``from_pretrained`` and
                ``weight_files``.
            revision: Optional revision forwarded to the tokenizer, config and
                weight-file lookup.
            quantize: Optional quantization mode forwarded to ``load_weights``.
            trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.

        Raises:
            NotImplementedError: if CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise NotImplementedError("FlashSantacoder is only available on GPU")
        device = torch.device("cuda")
        dtype = torch.float16

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = GPT2Config.from_pretrained(
            model_id,
            revision=revision,
        )

        # We do not use from_pretrained as we modified the model internal module layout
        filenames = weight_files(model_id, revision, ".safetensors")

        # Parameters start on the meta device; load_weights materializes them.
        with init_empty_weights():
            model = FlashSantacoderForCausalLM(config)

        self.load_weights(
            model,
            filenames,
            quantize,
            device,
            dtype,
            # GPT2-style checkpoints store linear weights in Conv1D layout,
            # which must be transposed for nn.Linear.
            config.architectures[0].startswith("GPT2"),
        )

        # Note: this resolves to the initializer *after* FlashCausalLM in the
        # MRO, so FlashCausalLM.__init__ itself is bypassed.
        super(FlashCausalLM, self).__init__(
            model=model.to(device),
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )

    @staticmethod
    def load_weights(
        model: FlashSantacoderForCausalLM,
        filenames: List[Path],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
        transpose: bool,
    ):
        """Materialize the meta-device parameters of ``model`` from safetensors files.

        ``q_attn`` tensors are written into the leading rows of the fused
        ``c_attn`` parameter, and ``kv_attn`` tensors into the rows starting at
        ``head_size * num_heads``. Keys that are not registered parameters of
        their module are stored as buffers. If ``lm_head`` is still unloaded
        afterwards, it is tied to the input embedding.

        Args:
            model: Model built under ``init_empty_weights``.
            filenames: ``.safetensors`` files to read.
            quantize: Quantization mode. When not ``None``, tensors are kept on
                CPU. Always forwarded to ``model.post_load_weights``.
            device: Device tensors are loaded to when ``quantize`` is ``None``.
            dtype: dtype every loaded tensor is cast to.
            transpose: Transpose linear weights stored in Conv1D layout.

        Raises:
            ValueError: if a non-fused tensor's shape differs from the model's
                parameter shape.
            RuntimeError: if any parameter is still on the meta device at the end.
        """
        # When quantizing, tensors are staged on CPU instead of the target device.
        load_device = str(device) if quantize is None else "cpu"
        meta = torch.device("meta")
        head_size = model.transformer.head_size
        num_heads = model.transformer.num_heads
        # Fused c_attn layout: head_size * num_heads query rows, then
        # 2 * head_size key/value rows.
        kv_offset = head_size * num_heads
        fused_qkv_size = head_size * (num_heads + 2)
        conv1d_weights = (
            "c_fc.weight",
            "c_proj.weight",
            "q_attn.weight",
            "kv_attn.weight",
            "c_attn.weight",
        )

        for filename in filenames:
            with safe_open(filename, framework="pt", device=load_device) as f:
                for key in f.keys():
                    value = f.get_tensor(key).to(load_device).to(dtype)

                    # Fused qkv: checkpoint q_attn/kv_attn tensors land in c_attn.
                    layer_name = ".".join(key.split(".")[:4])
                    if "q_attn.weight" in key or "kv_attn.weight" in key:
                        final_key = layer_name + ".c_attn.weight"
                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
                        final_key = layer_name + ".c_attn.bias"
                    else:
                        final_key = key

                    module_name, param_name = final_key.rsplit(".", 1)
                    module = model.get_submodule(module_name)
                    current_parameter_tensor = module._parameters.get(param_name)

                    if current_parameter_tensor is None:
                        # Not a (non-None) registered parameter: keep as a buffer.
                        module._buffers[param_name] = value
                    else:
                        if transpose and any(name in key for name in conv1d_weights):
                            # Transpose as we use nn.Linear instead of Conv1D
                            value = value.T

                        if current_parameter_tensor.device == meta:
                            # Allocate the full fused qkv parameter before its
                            # q / kv slices are filled in.
                            if "c_attn.weight" in final_key:
                                module._parameters[param_name] = value.new_empty(
                                    (fused_qkv_size, value.shape[1])
                                )
                            elif "c_attn.bias" in final_key:
                                module._parameters[param_name] = value.new_empty(
                                    fused_qkv_size
                                )

                        # Copy to correct slice
                        if "q_attn.weight" in key or "q_attn.bias" in key:
                            module._parameters[param_name][: value.shape[0]] = value
                        elif "kv_attn.weight" in key or "kv_attn.bias" in key:
                            module._parameters[param_name][kv_offset:] = value
                        else:
                            if current_parameter_tensor.shape != value.shape:
                                raise ValueError(
                                    f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
                                )
                            module._parameters[param_name] = value

                    # Drop the reference before the next tensor is read.
                    del value

        # Tie lm_head to the input embedding when no file provided it.
        if model.lm_head.weight.device == meta:
            model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)

        torch.cuda.empty_cache()
        model.post_load_weights(quantize)

        uninitialized_parameters = [
            name
            for name, param in model.named_parameters()
            if param.data.device == meta
        ]
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model : {uninitialized_parameters}"
            )

    def decode(self, generated_ids: List[int]) -> str:
        """Decode ``generated_ids`` to text, keeping special tokens verbatim."""
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )


class FlashSantacoderSharded(FlashSantacoder):
    """Tensor-parallel Santacoder: each rank loads its own shard onto ``cuda:{rank}``."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        """Initialize the process group, then load this rank's shard of ``model_id``.

        Args:
            model_id: Model id or path passed to ``from_pretrained`` and
                ``weight_files``.
            revision: Optional revision forwarded to the tokenizer, config and
                weight-file lookup.
            quantize: Optional quantization mode forwarded to ``load_weights``.
            trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.

        Raises:
            NotImplementedError: if CUDA is not available.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if not torch.cuda.is_available():
            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
        device = torch.device(f"cuda:{rank}")
        dtype = torch.float16

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = GPT2Config.from_pretrained(
            model_id,
            revision=revision,
        )

        # Barriers keep all ranks in step around weight discovery and loading.
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        # Parameters start on the meta device; load_weights materializes them.
        with init_empty_weights():
            model = FlashSantacoderForCausalLM(config, self.process_group)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=rank,
            world_size=world_size,
            # GPT2-style checkpoints store linear weights in Conv1D layout.
            transpose=config.architectures[0].startswith("GPT2"),
        )
        torch.distributed.barrier(group=self.process_group)
        # Note: bypasses FlashCausalLM.__init__ (next initializer in the MRO).
        super(FlashCausalLM, self).__init__(
            model=model.to(device),
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
        transpose: bool,
    ):
        """Load this rank's shard of every checkpoint tensor into ``model``.

        Sharding rules, by destination module:

        * ``TensorParallelColumnLinear``: weights are split along dim 1 when
          ``transpose`` is set. Otherwise (and always for biases) the split is
          along dim 0.
        * ``TensorParallelRowLinear``: weights are split along dim 0 when
          ``transpose`` is set, otherwise along dim 1. The bias is loaded whole
          on rank 0 and zeroed on the other ranks, so it is only added once.
        * ``TensorParallelEmbedding``, and ``lm_head.weight`` when
          ``model.transformer.tp_embeddings`` is set: split along dim 0.
        * Fused attention (``q_attn`` / ``kv_attn`` / ``c_attn``): query rows
          are split by rank and written to the front of ``c_attn``. The
          ``2 * head_size`` key/value rows are copied to every shard.
        * Anything else is loaded whole.

        NOTE(review): splits use floor division, so any remainder rows or
        columns when a size is not divisible by ``world_size`` are not loaded
        by any rank.

        Args:
            model: Model built under ``init_empty_weights`` with a process group.
            filenames: ``.safetensors`` files to read.
            quantize: Quantization mode. When not ``None``, tensors are read on
                CPU. Always forwarded to ``model.post_load_weights``.
            device: Device tensors are read onto when ``quantize`` is ``None``.
            dtype: dtype every loaded tensor is cast to.
            rank: This process's rank.
            world_size: Number of shards.
            transpose: Transpose linear weights stored in Conv1D layout.

        Raises:
            ValueError: if a non-fused tensor's shape differs from the model's
                parameter shape.
            RuntimeError: if any parameter is still on the meta device at the end.
        """

        def shard_range(size):
            # [start, stop) of the contiguous block owned by this rank.
            block_size = size // world_size
            return rank * block_size, (rank + 1) * block_size

        load_device = str(device) if quantize is None else "cpu"
        meta = torch.device("meta")
        head_size = model.transformer.head_size
        num_heads = model.transformer.num_heads
        # Fused c_attn layout: head_size * num_heads query rows, then
        # 2 * head_size key/value rows.
        kv_offset = head_size * num_heads
        fused_qkv_size = head_size * (num_heads + 2)
        conv1d_weights = (
            "c_fc.weight",
            "c_proj.weight",
            "q_attn.weight",
            "kv_attn.weight",
            "c_attn.weight",
        )

        for file in filenames:
            with safe_open(file, framework="pt", device=load_device) as f:
                for key in f.keys():
                    slice_ = f.get_slice(key)

                    # Fused qkv: checkpoint q_attn/kv_attn tensors land in c_attn.
                    layer_name = ".".join(key.split(".")[:4])
                    if "q_attn.weight" in key or "kv_attn.weight" in key:
                        final_key = layer_name + ".c_attn.weight"
                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
                        final_key = layer_name + ".c_attn.bias"
                    else:
                        final_key = key

                    module_name, param_name = final_key.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    if isinstance(module, TensorParallelColumnLinear):
                        dim = 1 if transpose and "weight" in param_name else 0
                        start, stop = shard_range(slice_.get_shape()[dim])
                        tensor = (
                            slice_[start:stop] if dim == 0 else slice_[:, start:stop]
                        )
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            dim = 0 if transpose else 1
                            start, stop = shard_range(slice_.get_shape()[dim])
                            tensor = (
                                slice_[start:stop]
                                if dim == 0
                                else slice_[:, start:stop]
                            )
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding) or (
                        key == "lm_head.weight" and model.transformer.tp_embeddings
                    ):
                        start, stop = shard_range(slice_.get_shape()[0])
                        tensor = slice_[start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except BaseException:
                            # Same catch-all as a bare `except:`. NOTE(review):
                            # presumably kept broad because slicing failures may
                            # surface as non-Exception errors (e.g. a pyo3
                            # PanicException); confirm before narrowing.
                            tensor = f.get_tensor(key)

                    tensor = tensor.contiguous().to(dtype)
                    current_parameter_tensor = module._parameters.get(param_name)

                    if current_parameter_tensor is None:
                        # Not a (non-None) registered parameter: keep as a buffer.
                        module._buffers[param_name] = tensor
                    else:
                        if transpose and any(name in key for name in conv1d_weights):
                            # Transpose as we use nn.Linear instead of Conv1D
                            tensor = tensor.T

                        if current_parameter_tensor.device == meta:
                            # Allocate this shard's fused qkv parameter before
                            # its q / kv slices are filled in.
                            if "c_attn.weight" in final_key:
                                module._parameters[param_name] = tensor.new_empty(
                                    (fused_qkv_size, tensor.shape[1])
                                )
                            elif "c_attn.bias" in final_key:
                                module._parameters[param_name] = tensor.new_empty(
                                    fused_qkv_size
                                )

                        # Copy to correct slice
                        if "q_attn" in key:
                            start, stop = shard_range(tensor.shape[0])
                            tensor = tensor[start:stop]
                            module._parameters[param_name][: tensor.shape[0]] = tensor
                        elif "kv_attn.weight" in key or "kv_attn.bias" in key:
                            module._parameters[param_name][kv_offset:] = tensor
                        elif "c_attn" in key:
                            # Slice q_tensor by shard
                            q_tensor = tensor[: -2 * head_size]
                            start, stop = shard_range(q_tensor.shape[0])
                            q_tensor = q_tensor[start:stop]
                            module._parameters[param_name][: q_tensor.shape[0]] = q_tensor

                            # Kv tensor is copied for every shard
                            kv_tensor = tensor[-2 * head_size :]
                            module._parameters[param_name][q_tensor.shape[0] :] = kv_tensor
                        else:
                            if current_parameter_tensor.shape != tensor.shape:
                                raise ValueError(
                                    f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                                )
                            module._parameters[param_name] = tensor

        # Tie lm_head to the input embedding when no file provided it.
        if model.lm_head.weight.device == meta:
            model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)

        torch.cuda.empty_cache()
        model.post_load_weights(quantize)

        # Same guard as the single-device loader: fail fast with a clear list
        # instead of a later error when a meta tensor is moved to the device.
        uninitialized_parameters = [
            name
            for name, param in model.named_parameters()
            if param.data.device == meta
        ]
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model : {uninitialized_parameters}"
            )