galactica.py 13.5 KB
Newer Older
1
2
3
4
import re
import torch
import torch.distributed

5
from typing import List, Optional, Type, Tuple
6
7
8

from accelerate import init_empty_weights
from safetensors import safe_open
9
10
11
12
13
14
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedTokenizerBase,
)
15
16
17
18
19
20
from transformers.models.opt.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

21
22
23
24
from text_generation_server.models import CausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import (
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
    NextTokenChooser,
    StoppingCriteria,
    initialize_torch_distributed,
    weight_files,
)

# bitsandbytes is an optional dependency that is only needed for int8
# quantization. Any failure while importing it is treated as "not available";
# the flag is checked before quantized weights are built.
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception:
    HAS_BITS_AND_BYTES = False
else:
    HAS_BITS_AND_BYTES = True


# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

# Matches a custom sequence span such as [START_DNA]...[END_DNA] so that the
# characters between the two special tokens can be split individually.
# Groups: 1 = opening token, 2 = sequence kind, 3 = sequence body (lazy match),
# 4 = closing token, whose kind must equal group 2 through the \2 backreference.
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# Marker token used to implement the custom sequence tokenization. It is added at
# the corpus cleaning step and removed in pretokenization. Digits are mixed into the
# marker to reduce the chance that it occurs naturally in the corpus. Some of the
# digits are written as f-string placeholders so that the final marker never appears
# literally in this source file, in case the file is ever included in training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"


def _insert_split_marker(m: re.Match):
    """
    Build the replacement text for one match of a custom sequence span such as
    [START_DNA]...[END_DNA].
    Parameters
    ----------
    m : re.Match
        Match with four groups: opening token, sequence kind, sequence body
        and closing token
    Returns
    ----------
    str - the span with the split marker placed before every character of the
    sequence body and once more before the closing token
    """
    opening, _kind, body, closing = m.groups()
    # Prefix every single character (newlines included) with the marker.
    marked_body = "".join(f"{SPLIT_MARKER}{char}" for char in body)
    return f"{opening}{marked_body}{SPLIT_MARKER}{closing}"


def escape_custom_split_sequence(text):
    """
    Insert split markers into every custom sequence span found in the text
    Parameters
    ----------
    text : str
        Input text that may contain spans such as [START_DNA]...[END_DNA]
    Returns
    ----------
    str - the text with the split marker added inside each matched span; text
    outside of matched spans is returned unchanged
    """
    escaped = re.sub(CUSTOM_SEQ_RE, _insert_split_marker, text)
    return escaped


# END CREDIT


class GalacticaCausalLMBatch(CausalLMBatch):
    """Causal LM batch whose request inputs are escaped for Galactica.

    The only difference with the parent batch construction visible here is that
    every request input goes through `escape_custom_split_sequence` before
    tokenization.
    """

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "GalacticaCausalLMBatch":
        """Build a batch from its protobuf description.

        Parameters
        ----------
        pb : generate_pb2.Batch
            Batch message; `pb.requests`, `pb.id` and `pb.size` are read
        tokenizer : PreTrainedTokenizerBase
            Tokenizer used to encode the escaped request inputs
        device : torch.device
            Device the tokenized tensors are moved to
        Returns
        ----------
        GalacticaCausalLMBatch - the batch, with no past key values yet
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        offsets = []
        token_offsets = []

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        for r in pb.requests:
            # Add escape_custom_split_sequence to the CausalLMBatch logic
            inputs.append(escape_custom_split_sequence(r.inputs))
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            # The whole batch is truncated to the largest per-request limit.
            max_truncation = max(max_truncation, r.truncate)
            # Room to reserve on the right for the longest possible generation.
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        # Positions count the attended tokens from 0; padded slots are set to 1.
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
        )


class Galactica(CausalLM):
    """Causal LM specialisation for Galactica models.

    Compared with the parent class, the code here selects the Galactica batch
    type, keeps special tokens when decoding, and runs the model without
    position ids.
    """

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used for this model."""
        return GalacticaCausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        """Decode generated token ids into text, keeping special tokens.

        Parameters
        ----------
        generated_ids : List[int]
            Token ids to decode
        Returns
        ----------
        str - the decoded text
        """
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        # NOTE: the tokenizer keyword is `clean_up_tokenization_spaces`; a
        # misspelled keyword would not disable the clean-up step.
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Overwrite forward to ignore position_ids

        `position_ids` is accepted to keep the signature callers use, but it is
        not forwarded to the model.

        Returns
        ----------
        tuple - the output logits and the updated past key values
        """

        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

186
187

class GalacticaSharded(Galactica):
    """Galactica model whose weights are split across a process group.

    Each rank loads only its own slice of the tensor-parallel layers and the
    per-rank logits are gathered in `forward`.
    """

    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        """Load the tokenizer, config and this rank's shard of the weights.

        Parameters
        ----------
        model_id : str
            Model identifier passed to the tokenizer, config and weight lookup
        revision : Optional[str]
            Model revision passed to the same calls
        quantize : bool
            When True, linear weights are converted to int8 with bitsandbytes
        """
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, tp_parallel=True
        )
        tokenizer.pad_token_id = config.pad_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        # Instantiate the architecture without allocating real weights; the
        # parameters are filled in by load_weights below.
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval().to(dtype)
        torch.distributed.barrier(group=self.process_group)
        # Deliberately starts the MRO lookup after CausalLM, so CausalLM.__init__
        # is skipped (presumably because it would load the model itself).
        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        rank: int,
        world_size: int,
    ):
        """Fill `model` with this rank's slice of the safetensors checkpoints.

        Parameters
        ----------
        model
            Model whose submodule parameters are replaced in place
        filenames : List[str]
            Safetensors files to read
        quantize : bool
            When True, weights of the tensor-parallel linear layers are turned
            into bitsandbytes int8 parameters
        device : torch.device
            Device the final tensors live on
        rank : int
            Rank of the current process, selects the slice to load
        world_size : int
            Number of processes the tensors are split across
        Raises
        ----------
        ImportError - if `quantize` is set but bitsandbytes could not be imported
        ValueError - if a loaded slice does not have the shape the model expects
        """
        # Fail before touching any file rather than on the first tensor.
        if quantize and not HAS_BITS_AND_BYTES:
            raise ImportError(
                "bitsandbytes is not available on your machine either because it is not installed "
                "or you don't have a GPU.\n"
                "You can install it with `pip install bitsandbytes`."
            )

        def shard_bounds(size):
            # Contiguous block of `size // world_size` entries owned by this rank.
            block_size = size // world_size
            return rank * block_size, (rank + 1) * block_size

        def replace_linear(state):
            # Factory so that every module gets a closure over its own state.
            def linear(input, weight, bias):
                out = bnb.matmul(
                    input,
                    weight,
                    state=state,
                    threshold=state.threshold,
                    bias=bias,
                )

                if state.CB is not None:
                    # we converted 8-bit row major to turing/ampere format
                    # in the first inference pass
                    # we no longer need the row-major weight
                    del state.CB
                    weight.data = state.CxB

                return out

            return linear

        parameters = dict(model.named_parameters())
        # Tensors that will be quantized are first read on CPU.
        load_device = "cpu" if quantize else str(device)
        for file in filenames:
            with safe_open(file, framework="pt", device=load_device) as f:
                for name in f.keys():
                    # lm_head shares the embedding weight, which is assigned
                    # when "model.decoder.embed_tokens.weight" is loaded.
                    if name == "lm_head.weight":
                        continue

                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)
                    current_tensor = parameters[name]

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        # Weight and bias are both split along the first axis.
                        start, stop = shard_bounds(slice_.get_shape()[0])
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            start, stop = shard_bounds(slice_.get_shape()[1])
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        start, stop = shard_bounds(slice_.get_shape()[0])
                        tensor = slice_[start:stop]
                    else:
                        tensor = slice_[:]

                    if current_tensor.shape != tensor.shape:
                        raise ValueError(
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous()

                    if quantize:
                        if (
                            isinstance(
                                module,
                                (TensorParallelRowLinear, TensorParallelColumnLinear),
                            )
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            # Move the quantized buffers from the parameter to
                            # the matmul state.
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

                            module.linear = replace_linear(state)
                        else:
                            tensor = tensor.to(device)

                    module._parameters[param_name] = tensor
                    if name == "model.decoder.embed_tokens.weight":
                        model.lm_head._parameters["weight"] = tensor

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        """Run the sharded model and gather the logits from every rank.

        `position_ids` is accepted but not forwarded to the model.

        Returns
        ----------
        tuple - the gathered logits and the updated past key values
        """
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        # Per-rank pieces are joined along dim 2 (presumably the vocabulary axis).
        logits = torch.cat(logits, dim=2)

        return logits, outputs.past_key_values