"cpu/compat.h" did not exist on "08dda1ad18230e75976ff968d7f5a8675e158e50"
galactica.py 14.4 KB
Newer Older
1
2
3
4
import re
import torch
import torch.distributed

from typing import List, Optional, Type, Tuple

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedTokenizerBase,
)
from transformers.models.opt.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.opt import OPT
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    initialize_torch_distributed,
    weight_files,
)

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception as e:
    HAS_BITS_AND_BYTES = False


# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

# we split individual characters inside special tokens like [START_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# Token added to implement a custom sequence tokenization. This token is added at
# the corpus cleaning step and removed in pretokenization. The digits are added to
# increase the chance that they do not occur in the corpus. The digits are escaped
# so that the token does not appear literally in the source code in case we ever
# include it in the training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"


def _insert_split_marker(m: re.Match):
    """
    Applies split marker based on a regex match of special tokens such as
    [START_DNA].
    Parameters
    ----------
    m : re.Match
        Regex match of a custom sequence block such as [START_DNA]...[END_DNA]
    Returns
    ----------
    str - the text with the split token added
    """
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"


def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for GALILEO's tokenization
    Parameters
    ----------
    text : str
        Input text to split
    Returns
    ----------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)


# END CREDIT

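# For illustration (hypothetical example, not part of the upstream galai helpers):
# escape_custom_split_sequence("[START_DNA]ACGT[END_DNA]") returns
# "[START_DNA]" + SPLIT_MARKER + "A" + SPLIT_MARKER + "C" + SPLIT_MARKER + "G"
# + SPLIT_MARKER + "T" + SPLIT_MARKER + "[END_DNA]" as one concatenated string,
# so the tokenizer later splits the DNA payload into single characters. Text
# outside [START_*]/[END_*] blocks is returned unchanged.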

class GalacticaCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "GalacticaCausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            # Add escape_custom_split_sequence to the CausalLMBatch logic
            inputs.append(escape_custom_split_sequence(r.inputs))
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)
        for _ in pb.requests:
            input_len = tokenized_inputs["input_ids"].shape[1]
            prefix_offsets.append(0)
            read_offsets.append(input_len)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
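        # The extra `padding_right_offset` columns stay zero for now; they are
        # filled in during decoding as new tokens are generated.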
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

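        # Position ids follow the attention mask so that left-padding tokens get a
        # dummy position instead of a negative one.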
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

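        # Upper bound on the tokens this batch can occupy: every prompt padded to
        # max_input_length plus all requested new tokens.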
        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )


class Galactica(OPT):
    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return GalacticaCausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Overwrite forward to ignore position_ids"""

        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values


class GalacticaSharded(Galactica):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token_id = config.pad_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=trust_remote_code
            )

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=rank,
            world_size=world_size,
        )
        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if quantize is None else "cpu"
            ) as f:
                for name in f.keys():
                    if name == "lm_head.weight":
                        continue

                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)
                    current_tensor = parameters[name]

                    slice_ = f.get_slice(name)

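                    # Shard the checkpoint tensor for this rank: column-parallel
                    # linears and embeddings are split along dim 0, row-parallel
                    # weights along dim 1 (their bias is materialized only on rank 0),
                    # and any other parameter is loaded whole.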
                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        tensor = slice_[:]

                    if current_tensor.shape != tensor.shape:
                        raise ValueError(
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)

                    if quantize == "bitsandbytes":
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

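                        # Quantize eligible tensor-parallel linear weights to int8 and
                        # monkey-patch the module's `linear` call to run bnb.matmul with
                        # the captured quantization state; other parameters are simply
                        # moved to the target device below.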
                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)
                        else:
                            tensor = tensor.to(device)
                    elif quantize == "gptq":
                        raise NotImplementedError("`gptq` is not implemented for now")
                    elif quantize is None:
                        tensor = tensor.to(device)
                    else:
                        raise ValueError(f"Unexpected quantize `{quantize}`")

                    module._parameters[param_name] = tensor
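                    # lm_head is tied to the input embeddings (its own weight was
                    # skipped above), so point it at the freshly loaded embedding shard.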
                    if name == "model.decoder.embed_tokens.weight":
                        model.lm_head._parameters["weight"] = tensor

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
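        # (each rank only holds its vocabulary shard, so all-gather and concatenate
        # along the last dimension to rebuild the full logits)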
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)

        return logits, outputs.past_key_values