"vscode:/vscode.git/clone" did not exist on "7c1692aa90b208b1925292f155908fa745a52335"
galactica.py 5.35 KB
Newer Older
1
2
3
4
5
import re
import torch
import torch.distributed


from transformers import (
    PreTrainedTokenizerBase,
)
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
)
from text_generation_server.utils.chunks import concat_text_chunks

# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

# we split individual characters inside special tokens like [START_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# Token added to implement custom sequence tokenization. It is inserted at the
# corpus cleaning step and removed during pretokenization. The digits make it
# unlikely that the marker occurs naturally in the corpus, and they are escaped
# via the f-string below so that the token never appears literally in the
# source code, in case this code is ever included in training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"


def _insert_split_marker(m: re.Match):
    """
    Applies split marker based on a regex match of special tokens such as
    [START_DNA].
    Parameters
    ----------
    m : re.Match
        Regex match of a custom sequence, e.g. [START_DNA]...[END_DNA]
    Returns
    -------
    str - the text with the split token added
    """
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"


def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for GALACTICA's tokenization
    Parameters
    ----------
    text : str
        Input text to split
    Returns
    -------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)


# END CREDIT
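
# Illustrative example (added; not part of the upstream galai code): the helper
# prefixes every character inside a delimited span with SPLIT_MARKER, forcing
# the tokenizer to split the sequence character by character:
#
#   >>> escape_custom_split_sequence("[START_SMILES]CO[END_SMILES]")
#   '[START_SMILES]SPL1T-TH1S-Pl3A5ECSPL1T-TH1S-Pl3A5EOSPL1T-TH1S-Pl3A5E[END_SMILES]'
#
# Text without matching [START_*]/[END_*] delimiters is returned unchanged.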


class GalacticaCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "GalacticaCausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        prefix_offsets = []
        top_n_tokens = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            # Add escape_custom_split_sequence to the CausalLMBatch logic
            inputs.append(
                escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
            )
            next_token_choosers.append(
                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
            )
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)
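        # Note (added): prompts are batch-padded to the longest input and
        # truncated to the largest per-request `truncate` value collected above.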
        for _ in pb.requests:
            input_len = tokenized_inputs["input_ids"].shape[1]
            prefix_offsets.append(0)
            read_offsets.append(input_len)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
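        # Shape note (added): with pb.size=2, max_input_length=5 and
        # padding_right_offset=3, attention_mask is (2, 8); the trailing 3
        # columns stay zero until decode steps fill them in.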

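        # Note (added): deriving position_ids from the attention mask keeps
        # padding from advancing positions: a mask of [0, 0, 1, 1, 1] gives
        # cumsum(-1) - 1 of [-1, -1, 0, 1, 2], and masked_fill_ then overwrites
        # the padded slots with a harmless constant 1.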
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

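        # Note (added): worst-case token budget for the batch: every prompt
        # padded to max_input_length plus all requested new tokens.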
        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )
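

# Hypothetical usage sketch (added): `pb_batch` and the argument values below
# are illustrative assumptions, not part of this module. The server performs
# this conversion on an incoming protobuf batch before running prefill:
#
#   batch = GalacticaCausalLMBatch.from_pb(
#       pb_batch,                        # generate_pb2.Batch from the router
#       tokenizer=tokenizer,             # PreTrainedTokenizerBase instance
#       dtype=torch.float16,
#       device=torch.device("cuda:0"),
#   )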