import re
import torch
import torch.distributed

from typing import List, Optional, Type

from transformers import (
    AutoTokenizer,
    AutoConfig,
    PreTrainedTokenizerBase,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks

# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

# we split individual characters inside special tokens like [START_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# Token added to implement custom sequence tokenization. It is inserted during the
# corpus cleaning step and removed again during pretokenization. The digits make it
# unlikely that the marker occurs naturally in the corpus, and they are escaped so
# that the token does not appear literally in the source code in case we ever
# include it in the training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"


def _insert_split_marker(m: re.Match):
    """
    Applies split marker based on a regex match of special tokens such as
    [START_DNA].
    Parameters
    ----------
    n : str
        Input text to split
    Returns
    ----------
    str - the text with the split token added
    """
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"


def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for GALILEO's tokenization
    Parameters
    ----------
    text : str
        Input text to split
    Returns
    ----------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)


# END CREDIT
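
# Example (illustrative): escaping a DNA span inserts SPLIT_MARKER before every
# inner character and once more before the end token, so that each character inside
# the special span can be tokenized individually (see the comments above):
#
#     escape_custom_split_sequence("[START_DNA]ACGT[END_DNA]")
#     == "[START_DNA]"
#        + SPLIT_MARKER + "A" + SPLIT_MARKER + "C"
#        + SPLIT_MARKER + "G" + SPLIT_MARKER + "T"
#        + SPLIT_MARKER + "[END_DNA]"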


class GalacticaCausalLMBatch(CausalLMBatch):
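    """CausalLMBatch that escapes Galactica's custom [START_*]...[END_*] sequences
    in the raw input text before tokenization."""
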
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "GalacticaCausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        prefix_offsets = []
        top_n_tokens = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            # Add escape_custom_split_sequence to the CausalLMBatch logic
            inputs.append(
                escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
            )
            next_token_choosers.append(
                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
            )
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)
        for _ in pb.requests:
            input_len = tokenized_inputs["input_ids"].shape[1]
            prefix_offsets.append(0)
            read_offsets.append(input_len)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

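        # Worst-case token budget for this batch: every request padded to the longest
        # prompt plus all requested new tokens.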
        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )


class GalacticaSharded(CausalLM):
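    """Sharded (tensor-parallel) Galactica model. Galactica reuses the OPT
    architecture, so the checkpoint is loaded into OPTForCausalLM."""
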
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        tokenizer.pad_token_id = config.pad_token_id
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = OPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return GalacticaCausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
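        # position_ids is accepted for interface compatibility but is not forwarded;
        # the OPT modeling code computes positions from the attention mask internally.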
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, speculative_logits, outputs.past_key_values