import re
import torch
import torch.distributed

from typing import List, Optional, Type

from transformers import (
    AutoTokenizer,
    AutoConfig,
    PreTrainedTokenizerBase,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    initialize_torch_distributed,
    weight_files,
    Weights,
)

# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

# we split individual characters inside special tokens like [START_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# token added to implement a custom sequence tokenization. This token is added at
# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
# that they do not occur in the corpus. The digits are escaped so that the token does not appear
# literally in the source code in case we ever include it in the training data.
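# (At runtime the f-string below evaluates to the literal marker "SPL1T-TH1S-Pl3A5E".)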
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"


def _insert_split_marker(m: re.Match):
    """
    Applies split marker based on a regex match of special tokens such as
    [START_DNA].
    Parameters
    ----------
    m : re.Match
        Regex match of a custom sequence block such as [START_DNA]...[END_DNA]
    Returns
    ----------
    str - the text with the split token added
    """
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"


def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for GALILEO's tokenization
    Parameters
    ----------
    text : str
        Input text to split
    Returns
    ----------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)


# END CREDIT
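# Illustrative example (not part of the credited galai code): for an input like
# "[START_SMILES]CCO[END_SMILES]", escape_custom_split_sequence produces roughly
# "[START_SMILES]<M>C<M>C<M>O<M>[END_SMILES]" where <M> stands for SPLIT_MARKER,
# so every character of the enclosed sequence ends up tokenized separately.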


class GalacticaCausalLMBatch(CausalLMBatch):
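    """Batch for Galactica models: follows the `CausalLMBatch` construction logic,
    but request inputs are passed through `escape_custom_split_sequence` before
    tokenization so that custom sequences are split character by character."""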
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "GalacticaCausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        prefix_offsets = []
        top_n_tokens = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            # Add escape_custom_split_sequence to the CausalLMBatch logic
            inputs.append(escape_custom_split_sequence(r.inputs))
            next_token_choosers.append(
                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
            )
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)
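        # Read offsets start at the padded prompt length (identical for every request
        # in the batch); prefix offsets start at 0.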
        for _ in pb.requests:
            input_len = tokenized_inputs["input_ids"].shape[1]
            prefix_offsets.append(0)
            read_offsets.append(input_len)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

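        # Position ids are the cumulative sum of the attention mask (minus one), so real
        # tokens get positions 0..n-1 even with left padding; padded slots are set to 1.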
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

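        # Token budget for this batch: padded prompt tokens plus the sum of
        # max_new_tokens over all requests.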
        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )


class GalacticaSharded(CausalLM):
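    """Sharded (tensor-parallel) Galactica model.

    Galactica checkpoints use the OPT architecture, so the custom sharded
    `OPTForCausalLM` implementation is reused here.
    """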
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        tokenizer.pad_token_id = config.pad_token_id
        config.use_medusa = use_medusa

        torch.distributed.barrier(group=self.process_group)
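        # Only .safetensors weight files are considered for sharded loading.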
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize == "gptq":
            weights._set_gptq_params(model_id, revision)

        model = OPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return GalacticaCausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
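        # The sharded OPT forward returns (outputs, speculative_logits); the speculative
        # logits are presumably only populated when Medusa speculation is configured.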
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, speculative_logits, outputs.past_key_values