import re
import torch
import torch.distributed

from typing import List, Optional, Type, Tuple

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedTokenizerBase,
)
from transformers.models.opt.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.opt import OPT
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    initialize_torch_distributed,
    weight_files,
)

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception as e:
    HAS_BITS_AND_BYTES = False


# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

# we split individual characters inside special tokens like [START_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# token added to implement a custom sequence tokenization. This token is added at
# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
# that they do not occur in the corpus. The digits are escaped so that the token does not appear
# literally in the source code in case we ever include it in the training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"


def _insert_split_marker(m: re.Match):
    """
    Applies split marker based on a regex match of special tokens such as
    [START_DNA].
    Parameters
    ----------
    m : re.Match
        Regex match of a custom [START_*]...[END_*] sequence
    Returns
    ----------
    str - the text with the split token added
    """
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"


def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for GALILEO's tokenization
    Parameters
    ----------
    text : str
        Input text to split
    Returns
    ----------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)


# END CREDIT
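
# Illustrative example (added for clarity, not part of the upstream galai code):
# every character inside a recognized [START_*]...[END_*] block is prefixed with
# SPLIT_MARKER, and one more marker is inserted before the end token, so
#
#   escape_custom_split_sequence("[START_SMILES]CCO[END_SMILES]")
#
# returns "[START_SMILES]" + "SPL1T-TH1S-Pl3A5E" + "C" + "SPL1T-TH1S-Pl3A5E" + "C"
# + "SPL1T-TH1S-Pl3A5E" + "O" + "SPL1T-TH1S-Pl3A5E" + "[END_SMILES]" joined into a
# single string, while text without such blocks is returned unchanged.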


class GalacticaCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "GalacticaCausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        offsets = []
        token_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            # Add escape_custom_split_sequence to the CausalLMBatch logic
            inputs.append(escape_custom_split_sequence(r.inputs))
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
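        # Resulting layout of each row (illustrative; the tokenizer pads on the left):
        #   [ 0 ... 0 | 1 ... 1 | 0 ... 0 ]
        #    padding    prompt    reserved for up to padding_right_offset new tokens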

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )


class Galactica(OPT):
    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return GalacticaCausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Overwrite forward to ignore position_ids"""

        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values


class GalacticaSharded(Galactica):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        self.master = rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, tp_parallel=True
        )
        tokenizer.pad_token_id = config.pad_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=rank,
            world_size=world_size,
        )
        self.model = model.eval()
        torch.distributed.barrier(group=self.process_group)
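        # Skip CausalLM.__init__ (which would load a model itself) and initialize the
        # base class that CausalLM inherits from; self.model was already built and
        # sharded above.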
        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    if name == "lm_head.weight":
                        continue

                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)
                    current_tensor = parameters[name]

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        tensor = slice_[:]

                    if current_tensor.shape != tensor.shape:
                        raise ValueError(
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)

                    if quantize == "bitsandbytes":
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None
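                            # Roughly (based on how bitsandbytes' Int8Params behaves):
                            # .to(device) quantized the fp16 weight into int8 data (CB)
                            # plus per-column absmax scales (SCB); moving both onto the
                            # MatmulLtState and clearing them on the parameter lets the
                            # custom linear below run bnb.matmul's int8 kernels.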

                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)
                        else:
                            tensor = tensor.to(device)
                    elif quantize == "gptq":
                        raise NotImplementedError(
                            "`gptq` is not implemented for now"
                        )
                    elif quantize is None:
                        tensor = tensor.to(device)
                    else:
                        raise ValueError(f"Unexpected quantize `{quantize}`")

                    module._parameters[param_name] = tensor
                    if name == "model.decoder.embed_tokens.weight":
                        model.lm_head._parameters["weight"] = tensor

        uninitialized_parameters = []
        for n, p in model.named_parameters():
            if p.data.device == torch.device("meta"):
                uninitialized_parameters.append(n)
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model: {uninitialized_parameters}"
            )
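
    # Sharding arithmetic used above, spelled out for clarity (assuming each sharded
    # dimension divides evenly by world_size), e.g. size=8192, world_size=4, rank=1:
    #
    #     block_size = size // world_size                             # 2048
    #     start, stop = rank * block_size, (rank + 1) * block_size    # 2048, 4096
    #
    # TensorParallelColumnLinear and TensorParallelEmbedding slice dim 0 of the
    # weight, TensorParallelRowLinear slices dim 1 and zeroes its bias on every rank
    # except 0 so the bias is only added once, and everything else is replicated.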

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
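        # Each rank only holds the logits for its shard of the vocabulary (the lm_head
        # weight is tied to the vocab-sharded embedding in load_weights), so all_gather
        # collects one (batch, seq_len, vocab_size // world_size) tensor per rank and
        # the concatenation on dim=2 restores the full vocabulary dimension.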
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)

        return logits, outputs.past_key_values