utils.py 7.42 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from threading import Lock
from typing import Any, Callable, Generator, List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from pydantic import BaseModel, Field

try:
    from transformers.generation_logits_process import (
        LogitsProcessorList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )
except ImportError:
    from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper


def prepare_logits_processor(top_k: Optional[int] = None,
                             top_p: Optional[float] = None,
                             temperature: Optional[float] = None) -> LogitsProcessorList:
    """Assemble the logits warpers used when sampling the next token.

    A warper is only added when its setting would actually change the
    distribution, and warpers are applied in the order temperature, top-k,
    top-p.

    Args:
        top_k: keep only the ``top_k`` highest-scoring tokens; ``None`` or 0 disables it.
        top_p: nucleus-sampling threshold; ``None`` or a value >= 1.0 disables it.
        temperature: logits temperature; ``None`` or 1.0 disables it.

    Returns:
        A ``LogitsProcessorList``, empty when every setting is a no-op.
    """
    processors = LogitsProcessorList()
    # 1.0 is the identity temperature.
    if temperature not in (None, 1.0):
        processors.append(TemperatureLogitsWarper(temperature))
    # top-k of 0 means "no top-k filtering".
    if top_k not in (None, 0):
        processors.append(TopKLogitsWarper(top_k))
    # top-p >= 1.0 keeps the whole distribution.
    if top_p is not None and top_p < 1.0:
        processors.append(TopPLogitsWarper(top_p))
    return processors


def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
    if dist.is_initialized() and dist.get_world_size() > 1:
        # consider DP
        unfinished_sequences = unfinished_sequences.clone()
        dist.all_reduce(unfinished_sequences)
    return unfinished_sequences.max() == 0


def sample_streamingly(model: nn.Module,
                       input_ids: torch.Tensor,
                       max_generate_tokens: int,
                       early_stopping: bool = False,
                       eos_token_id: Optional[int] = None,
                       pad_token_id: Optional[int] = None,
                       top_k: Optional[int] = None,
                       top_p: Optional[float] = None,
                       temperature: Optional[float] = None,
                       prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
                       update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
                       **model_kwargs) -> Generator:
    """Sample up to ``max_generate_tokens`` tokens, yielding after every step.

    Each step runs ``model`` once, takes the logits at the last position,
    applies the temperature/top-k/top-p warpers, and draws one token per
    sequence with ``torch.multinomial``.

    Args:
        model: called as ``model(**model_inputs)``; its output must support
            ``outputs['logits']`` indexed as ``[:, -1, :]`` (assumed
            (batch, seq, vocab) — TODO confirm for custom models).
        input_ids: prompt token ids; new tokens are appended along the last dim.
        max_generate_tokens: maximum number of sampling steps.
        early_stopping: stop once every sequence (across DP ranks) has hit EOS.
        eos_token_id: token marking a sequence as finished; requires ``pad_token_id``.
        pad_token_id: token emitted for sequences that have already finished.
        top_k, top_p, temperature: sampling settings (see ``prepare_logits_processor``).
        prepare_inputs_fn: builds model inputs from ``input_ids`` and
            ``model_kwargs``; defaults to ``{'input_ids': input_ids}``.
        update_model_kwargs_fn: returns the next step's ``model_kwargs`` from
            the model outputs and the current kwargs.
        **model_kwargs: extra state forwarded to the two callbacks.

    Yields:
        A tensor holding the token sampled for each sequence at this step.

    Raises:
        ValueError: if ``eos_token_id`` is given without ``pad_token_id``.
            Like all generator code this surfaces on the first ``next()``, but
            now before any forward pass is run.
    """
    # Validate the configuration before doing any (potentially expensive) model work.
    if eos_token_id is not None and pad_token_id is None:
        raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")

    logits_processor = prepare_logits_processor(top_k, top_p, temperature)
    # 1 while a sequence is still generating, 0 once it has produced EOS.
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

    for _ in range(max_generate_tokens):
        if prepare_inputs_fn is not None:
            model_inputs = prepare_inputs_fn(input_ids, **model_kwargs)
        else:
            model_inputs = {'input_ids': input_ids}
        outputs = model(**model_inputs)

        # Only the last position's logits are needed to choose the next token.
        next_token_logits = outputs['logits'][:, -1, :]
        # pre-process distribution
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # sample
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # Finished sequences still get a token every step, but it is the pad token.
        if eos_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        yield next_tokens

        # update generated ids, model inputs for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished if early_stopping=True
        if early_stopping and _is_sequence_finished(unfinished_sequences):
            break


def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
    """Advance the model keyword arguments by one generated token.

    - ``past`` is set to ``outputs["past_key_values"]`` when present, else ``None``.
    - ``token_type_ids``, if given, gains one column repeating its last column.
    - ``attention_mask``, if given, gains one column of ones.

    Args:
        outputs: model outputs from the step that was just run.
        **model_kwargs: the kwargs used for that step (collected into a new dict).

    Returns:
        The updated kwargs dict.
    """
    model_kwargs["past"] = outputs["past_key_values"] if "past_key_values" in outputs else None

    if "token_type_ids" in model_kwargs:
        type_ids = model_kwargs["token_type_ids"]
        # the new token inherits the segment id of the previous last token
        repeated_last = type_ids[:, -1].unsqueeze(-1)
        model_kwargs["token_type_ids"] = torch.cat([type_ids, repeated_last], dim=-1)

    if "attention_mask" in model_kwargs:
        mask = model_kwargs["attention_mask"]
        # the newly generated token is always attended to
        ones_column = mask.new_ones((mask.shape[0], 1))
        model_kwargs["attention_mask"] = torch.cat([mask, ones_column], dim=-1)

    return model_kwargs


class Dialogue(BaseModel):
    """One chat turn: a user instruction and the response to it (both required)."""

    # Non-empty user instruction.
    instruction: str = Field(..., min_length=1, example='Count up from 1 to 500.')
    # Response text; ChatPromptProcessor expects '' on the final, still-pending turn.
    response: str = Field(..., example='')


def _format_dialogue(instruction: str, response: str = ''):
    return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'


class ChatPromptProcessor:
    """Build a token-bounded chat prompt from a dialogue history.

    The prompt is ``context``, then as many of the most recent completed turns
    as fit, then the pending (last) instruction. Room for ``max_new_tokens``
    generated tokens is reserved inside ``max_len``.

    Args:
        tokenizer: called as ``tokenizer(text, add_special_tokens=..., truncation=...,
            max_length=...)`` returning a mapping with ``'input_ids'``, and
            providing ``decode(ids)``.
        context: text placed at the start of every prompt.
        max_len: maximum number of tokens for prompt plus generated tokens.
    """

    def __init__(self, tokenizer, context: str, max_len: int = 2048):
        self.tokenizer = tokenizer
        self.context = context
        self.max_len = max_len
        # These will be initialized after the first call of preprocess_prompt()
        # and then reused, since they never change for this instance.
        self.context_len: Optional[int] = None
        self.dialogue_placeholder_len: Optional[int] = None

    def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str:
        """Return a prompt that fits in ``max_len - max_new_tokens`` tokens.

        The last entry of ``history`` is the pending turn and is always
        included; its response must be ''.

        - If that turn does not fit next to the context, its instruction is
          truncated and all earlier turns are dropped.
        - Otherwise earlier turns are added newest first, until the next one
          would overflow the remaining budget.

        ``history`` itself is not modified. Token counts are measured per
        segment, so the length of the concatenated prompt may differ slightly
        depending on the tokenizer.

        Raises:
            IndexError: if ``history`` is empty.
            AssertionError: if the last turn already has a response (this is an
                ``assert``, so it is skipped under ``python -O``).
        """
        if self.context_len is None:
            self.context_len = len(self.tokenizer(self.context)['input_ids'])
        if self.dialogue_placeholder_len is None:
            self.dialogue_placeholder_len = len(
                self.tokenizer(_format_dialogue(''), add_special_tokens=False)['input_ids'])

        # The last dialogue must be in the prompt. Index rather than pop() so
        # the caller's history list is left intact.
        last_dialogue = history[-1]
        earlier_dialogues = history[:-1]
        # the response of the last dialogue is empty
        assert last_dialogue.response == ''

        last_text = _format_dialogue(last_dialogue.instruction)
        last_len = len(self.tokenizer(last_text, add_special_tokens=False)['input_ids'])
        # Tokens available to the prompt once room for generation is reserved.
        prompt_budget = self.max_len - max_new_tokens

        if last_len + self.context_len >= prompt_budget:
            # To avoid truncating the template markers, truncate the raw instruction.
            # NOTE(review): if context + placeholder + max_new_tokens already reach
            # max_len, this max_length is <= 0 and the result depends on the tokenizer.
            instruction_ids = self.tokenizer(last_dialogue.instruction,
                                             add_special_tokens=False,
                                             truncation=True,
                                             max_length=(prompt_budget - self.context_len -
                                                         self.dialogue_placeholder_len))['input_ids']
            instruction_truncated = self.tokenizer.decode(instruction_ids).lstrip()
            return self.context + _format_dialogue(instruction_truncated)

        # Budget for earlier turns: reserve the context AND the pending turn,
        # so the final prompt cannot exceed prompt_budget.
        remaining = prompt_budget - self.context_len - last_len

        rows = []
        for dialogue in reversed(earlier_dialogues):
            text = _format_dialogue(dialogue.instruction, dialogue.response)
            cur_len = len(self.tokenizer(text, add_special_tokens=False)['input_ids'])
            if cur_len > remaining:
                break
            remaining -= cur_len
            rows.append(text)
        # rows were collected newest-first; restore chronological order
        rows.reverse()
        return self.context + ''.join(rows) + last_text


class LockedIterator:
    """Iterator wrapper that advances the wrapped iterator only while holding a lock.

    This lets several threads share one underlying iterator without two
    ``next()`` calls running inside it at the same time.
    """

    def __init__(self, it, lock: Lock) -> None:
        # Any iterable is accepted; it is turned into an iterator once, here.
        self.it = iter(it)
        self.lock = lock

    def __iter__(self):
        """Return this wrapper itself, as the iterator protocol requires."""
        return self

    def __next__(self):
        """Return the next item; StopIteration propagates once exhausted."""
        with self.lock:
            item = next(self.it)
        return item