import re
import torch
import torch.distributed

from typing import List, Optional, Type, Tuple

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedTokenizerBase,
)
from transformers.models.opt.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.opt import OPT
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    initialize_torch_distributed,
    weight_files,
)

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception as e:
    HAS_BITS_AND_BYTES = False


# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

# we split individual characters inside special tokens like [START_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# token added to implement a custom sequence tokenization. This token is added at
# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
# that they do not occur in the corpus. The digits are escaped so that the token does not appear
# literally in the source code in case we ever include it in the training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"


def _insert_split_marker(m: re.Match):
    """
    Applies split marker based on a regex match of special tokens such as
    [START_DNA].
    Parameters
    ----------
    m : re.Match
        Regex match covering a custom sequence span such as [START_DNA]...[END_DNA]
    Returns
    ----------
    str - the text with the split token added
    """
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"


def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for Galactica's tokenization
    Parameters
    ----------
    text : str
        Input text to split
    Returns
    ----------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)


# END CREDIT
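
# Illustrative example (not part of the upstream file): the escaping above
# inserts SPLIT_MARKER before every character of the wrapped payload and once
# more before the end token, e.g.
#
#   escape_custom_split_sequence("[START_SMILES]CO[END_SMILES]")
#   -> "[START_SMILES]SPL1T-TH1S-Pl3A5ECSPL1T-TH1S-Pl3A5EOSPL1T-TH1S-Pl3A5E[END_SMILES]"
#
# so the tokenizer later sees the SMILES/DNA payload character by character.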


class GalacticaCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "GalacticaCausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            # Add escape_custom_split_sequence to the CausalLMBatch logic
            inputs.append(escape_custom_split_sequence(r.inputs))
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)
        for _ in pb.requests:
            input_len = tokenized_inputs["input_ids"].shape[1]
            prefix_offsets.append(0)
            read_offsets.append(input_len)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
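        # Illustrative: with max_input_length=3 and padding_right_offset=2 the
        # mask has 5 columns; the two rightmost stay zero here and are turned on
        # one by one as new tokens are generated during decoding.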

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
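        # e.g. a left-padded mask [0, 1, 1] gives cumsum - 1 = [-1, 0, 1]; the
        # masked_fill_ then sets the padding slot to 1, yielding [1, 0, 1].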
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

        max_tokens = len(inputs) * max_input_length + max_decode_tokens
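        # Illustrative budget (hypothetical numbers): 4 requests padded to 20
        # input tokens with max_new_tokens summing to 64 gives
        # max_tokens = 4 * 20 + 64 = 144.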

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )


class Galactica(OPT):
    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return GalacticaCausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Overwrite forward to ignore position_ids"""

        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values


class GalacticaSharded(Galactica):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token_id = config.pad_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=trust_remote_code
            )

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=rank,
            world_size=world_size,
        )
        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if quantize is None else "cpu"
            ) as f:
                for name in f.keys():
                    if name == "lm_head.weight":
                        continue

                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)
                    current_tensor = parameters[name]

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        tensor = slice_[:]
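
                    # Illustrative sharding (hypothetical shapes, world_size=2):
                    # a column-parallel weight of shape (8192, 2048) is split
                    # along dim 0 (rank 0 gets rows 0..4095, rank 1 rows
                    # 4096..8191); a row-parallel weight of shape (2048, 8192)
                    # is split along dim 1; embeddings shard their vocabulary
                    # dimension; row-parallel biases are zeroed on every rank
                    # but rank 0 so the reduced output adds the bias only once.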

                    if current_tensor.shape != tensor.shape:
                        raise ValueError(
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)

                    if quantize == "bitsandbytes":
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None
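                            # The quantized int8 weight (CB) and its scale
                            # factors (SCB) now live on the matmul state;
                            # bnb.matmul reads them from there, so the parameter
                            # itself no longer needs to carry them.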

                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)
                        else:
                            tensor = tensor.to(device)
                    elif quantize == "gptq":
                        raise NotImplementedError(
                            "`gptq` is not implemented for now"
                        )
                    elif quantize is None:
                        tensor = tensor.to(device)
                    else:
                        raise ValueError(f"Unexpected quantize `{quantize}`")

                    module._parameters[param_name] = tensor
                    if name == "model.decoder.embed_tokens.weight":
                        model.lm_head._parameters["weight"] = tensor

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)
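        # e.g. with world_size=2 each rank holds logits of shape
        # (batch, seq, vocab_size // 2); all_gather followed by cat(dim=2)
        # rebuilds the full (batch, seq, vocab_size) logits on every rank.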

        return logits, outputs.past_key_values