"""Flash-attention Santacoder model wrappers for text-generation-server.

Defines a single-GPU model (``FlashSantacoder``) and a tensor-parallel
variant (``FlashSantacoderSharded``) that load checkpoint weights directly
into ``FlashSantacoderForCausalLM``.
"""
from pathlib import Path
from typing import Optional, List

import torch
import torch.distributed

from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, GPT2Config

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
    FlashSantacoderForCausalLM,
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    download_weights,
    weight_hub_files,
    LocalEntryNotFoundError,
)

tracer = trace.get_tracer(__name__)


class FlashSantacoder(FlashCausalLM):
    """Single-GPU flash-attention Santacoder model.

    Loads a ``GPT2Config`` and PyTorch ``.bin`` checkpoints into a
    ``FlashSantacoderForCausalLM``. That model stores the query and key/value
    projections as one fused ``c_attn`` parameter, so the checkpoint's
    separate ``q_attn`` / ``kv_attn`` tensors are copied into slices of it.
    """

    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        """Build the tokenizer and model and load weights onto ``cuda``.

        Args:
            model_id: hub id or local path of the model.
            revision: optional hub revision.
            quantize: forwarded to ``load_weights``; when true, tensors are
                first materialized on CPU.

        Raises:
            NotImplementedError: CUDA is not available.
        """
        self.past_pad = None

        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashSantacoder is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = GPT2Config.from_pretrained(
            model_id,
            revision=revision,
        )

        # We do not use from_pretrained as we modified the model internal module layout
        try:
            filenames = weight_files(model_id, revision, ".bin")
        # Local files not found: fetch them from the hub
        except LocalEntryNotFoundError:
            hub_files = weight_hub_files(model_id, revision, ".bin")
            filenames = download_weights(hub_files, model_id, revision)

        # Parameters start on the meta device and are filled in by load_weights
        with init_empty_weights():
            model = FlashSantacoderForCausalLM(config)

        self.load_weights(
            model,
            filenames,
            quantize,
            device,
            dtype,
            config.architectures[0].startswith("GPT2"),
        )
        self.model = model.eval().to(device)

        # Deliberately skips FlashCausalLM.__init__ and runs the next __init__
        # in the MRO (NOTE(review): presumably the base Model class — confirm).
        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            decode_buffer=1,
        )

    @staticmethod
    def load_weights(
        model: FlashSantacoderForCausalLM,
        filenames: List[Path],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
        transpose: bool,
    ):
        """Copy every tensor from the ``torch.load``-able ``filenames`` into ``model``.

        Args:
            model: model created under ``init_empty_weights``.
            filenames: PyTorch checkpoint files.
            quantize: when true, tensors are kept on CPU instead of ``device``;
                also forwarded to ``model.post_load_weights``.
            device: target device when not quantizing.
            dtype: dtype every tensor is cast to.
            transpose: transpose the linear weights listed in
                ``transposed_names`` (the caller sets this for "GPT2*"
                architectures).

        Raises:
            ValueError: a non-qkv tensor's shape differs from the parameter it
                replaces.
        """
        # Checkpoint weights that are transposed because the model uses
        # nn.Linear instead of Conv1D.
        transposed_names = (
            "c_fc.weight",
            "c_proj.weight",
            "q_attn.weight",
            "kv_attn.weight",
            "c_attn.weight",
        )
        head_size = model.transformer.head_size
        num_heads = model.transformer.num_heads
        # Fused c_attn rows: all query heads followed by two extra head-sized
        # blocks (the key/value projection).
        q_size = head_size * num_heads
        fused_size = head_size * (num_heads + 2)

        for filename in filenames:
            state_dict = torch.load(filename, map_location="cpu")
            for key, value in state_dict.items():
                value = value.to(device if not quantize else "cpu").to(dtype)

                layer_name = ".".join(key.split(".")[:4])

                # Fused qkv: q_attn and kv_attn both land in c_attn
                if "q_attn.weight" in key or "kv_attn.weight" in key:
                    final_key = layer_name + ".c_attn.weight"
                elif "q_attn.bias" in key or "kv_attn.bias" in key:
                    final_key = layer_name + ".c_attn.bias"
                else:
                    final_key = key

                module_name, param_name = final_key.rsplit(".", 1)
                module = model.get_submodule(module_name)

                # None when the name is not a registered parameter
                current_parameter_tensor = module._parameters.get(param_name)

                if current_parameter_tensor is not None:
                    if transpose and any(name in key for name in transposed_names):
                        # Transpose as we use nn.Linear instead of Conv1D
                        value = value.T

                    if current_parameter_tensor.device == torch.device("meta"):
                        # Allocate the fused qkv tensor on first use
                        if "c_attn.weight" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (fused_size, value.shape[1])
                            )
                        elif "c_attn.bias" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (fused_size,)
                            )

                    # Copy to correct slice of the fused tensor
                    if "q_attn.weight" in key or "q_attn.bias" in key:
                        module._parameters[param_name][: value.shape[0]] = value
                    elif "kv_attn.weight" in key or "kv_attn.bias" in key:
                        module._parameters[param_name][q_size:] = value
                    else:
                        if current_parameter_tensor.shape != value.shape:
                            raise ValueError(
                                f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
                            )
                        module._parameters[param_name] = value
                else:
                    module._buffers[param_name] = value

                del value

            # Release this file's tensors before the next torch.load so two
            # full state dicts are never held at once.
            del state_dict

        torch.cuda.empty_cache()
        model.post_load_weights(quantize)

    def decode(self, generated_ids: List[int]) -> str:
        """Decode token ids to text, keeping special tokens."""
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )


class FlashSantacoderSharded(FlashSantacoder):
    """Tensor-parallel flash-attention Santacoder model.

    Each rank loads only its shard of the ``.safetensors`` weights onto
    ``cuda:{rank}``.
    """

    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        """Initialize the process group, build the sharded model and load weights.

        Raises:
            NotImplementedError: CUDA is not available.
        """
        self.past_pad = None
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = GPT2Config.from_pretrained(
            model_id,
            revision=revision,
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        # Parameters start on the meta device and are filled in by load_weights
        with init_empty_weights():
            model = FlashSantacoderForCausalLM(config, self.process_group)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=self.rank,
            world_size=self.world_size,
            transpose=config.architectures[0].startswith("GPT2"),
        )
        self.model = model.eval().to(device)

        torch.distributed.barrier(group=self.process_group)
        # Same decode buffer as the single-GPU FlashSantacoder, whose decode()
        # this class inherits.
        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            decode_buffer=1,
        )

    @staticmethod
    def _get_shard(slice_, dim: int, rank: int, world_size: int):
        """Return ``rank``'s equal-sized block of a safetensors slice along ``dim`` (0 or 1)."""
        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        return slice_[start:stop] if dim == 0 else slice_[:, start:stop]

    @staticmethod
    def load_weights(
        model: FlashSantacoderForCausalLM,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
        transpose: bool,
    ):
        """Load this rank's shard of every safetensors tensor into ``model``.

        Args:
            model: model created under ``init_empty_weights`` with a process group.
            filenames: ``.safetensors`` files.
            quantize: when true, tensors are read on CPU instead of ``device``;
                also forwarded to ``model.post_load_weights``.
            device: target device when not quantizing.
            dtype: dtype every tensor is cast to.
            rank: this process's rank.
            world_size: number of ranks tensors are split across.
            transpose: transpose the linear weights listed in
                ``transposed_names`` (the caller sets this for "GPT2*"
                architectures).

        Raises:
            ValueError: a non-qkv tensor's shape differs from the parameter it
                replaces.
        """
        # Checkpoint weights that are transposed because the model uses
        # nn.Linear instead of Conv1D.
        transposed_names = (
            "c_fc.weight",
            "c_proj.weight",
            "q_attn.weight",
            "kv_attn.weight",
            "c_attn.weight",
        )
        head_size = model.transformer.head_size
        num_heads = model.transformer.num_heads
        # Fused c_attn rows on this shard: local query heads followed by two
        # head-sized blocks (key/value, replicated on every shard).
        q_size = head_size * num_heads
        fused_size = head_size * (num_heads + 2)

        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for key in f.keys():
                    slice_ = f.get_slice(key)

                    layer_name = ".".join(key.split(".")[:4])

                    # Fused qkv: q_attn and kv_attn both land in c_attn
                    if "q_attn.weight" in key or "kv_attn.weight" in key:
                        final_key = layer_name + ".c_attn.weight"
                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
                        final_key = layer_name + ".c_attn.bias"
                    else:
                        final_key = key

                    module_name, param_name = final_key.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    if isinstance(module, TensorParallelColumnLinear):
                        # Split the output features (dim 1 in Conv1D layout)
                        dim = 1 if transpose and "weight" in param_name else 0
                        tensor = FlashSantacoderSharded._get_shard(
                            slice_, dim, rank, world_size
                        )
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            # Split the input features (dim 0 in Conv1D layout)
                            dim = 0 if transpose else 1
                            tensor = FlashSantacoderSharded._get_shard(
                                slice_, dim, rank, world_size
                            )
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding) or (
                        key == "lm_head.weight" and model.transformer.tp_embeddings
                    ):
                        # Split the vocabulary rows
                        tensor = FlashSantacoderSharded._get_shard(
                            slice_, 0, rank, world_size
                        )
                    else:
                        try:
                            tensor = slice_[:]
                        except Exception:
                            # Fall back to a full read when the slice cannot be indexed
                            tensor = f.get_tensor(key)

                    tensor = tensor.contiguous().to(dtype)

                    # None when the name is not a registered parameter
                    current_parameter_tensor = module._parameters.get(param_name)

                    if current_parameter_tensor is not None:
                        if transpose and any(name in key for name in transposed_names):
                            # Transpose as we use nn.Linear instead of Conv1D
                            tensor = tensor.T

                        if current_parameter_tensor.device == torch.device("meta"):
                            # Allocate the fused qkv tensor on first use
                            if "c_attn.weight" in final_key:
                                module._parameters[param_name] = tensor.new_empty(
                                    (fused_size, tensor.shape[1])
                                )
                            elif "c_attn.bias" in final_key:
                                module._parameters[param_name] = tensor.new_empty(
                                    (fused_size,)
                                )

                        # Copy to correct slice of the fused tensor
                        if "q_attn" in key:
                            # Query rows are split across ranks
                            block_size = tensor.shape[0] // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = tensor[start:stop]
                            module._parameters[param_name][: tensor.shape[0]] = tensor
                        elif "kv_attn.weight" in key or "kv_attn.bias" in key:
                            module._parameters[param_name][q_size:] = tensor
                        elif "c_attn" in key:
                            # Slice q_tensor by shard
                            q_tensor = tensor[: -2 * head_size]
                            block_size = q_tensor.shape[0] // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            q_tensor = q_tensor[start:stop]
                            module._parameters[param_name][
                                : q_tensor.shape[0]
                            ] = q_tensor

                            # Kv tensor is copied for every shard
                            kv_tensor = tensor[-2 * head_size :]
                            module._parameters[param_name][
                                q_tensor.shape[0] :
                            ] = kv_tensor
                        else:
                            if current_parameter_tensor.shape != tensor.shape:
                                raise ValueError(
                                    f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                                )

                            module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor
        torch.cuda.empty_cache()
        model.post_load_weights(quantize)