utils.py 6.57 KB
Newer Older
1
import concurrent
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
import os
3
import re
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
4
5
import torch
import torch.distributed
Olivier Dehaene's avatar
Olivier Dehaene committed
6
7

from datetime import timedelta
Nicolas Patry's avatar
Nicolas Patry committed
8

9
from concurrent.futures import ThreadPoolExecutor
Nicolas Patry's avatar
Nicolas Patry committed
10
11
12
13
from functools import partial
from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
14
from typing import List, Optional, Tuple
15
from transformers import PreTrainedTokenizerBase
16
from transformers.generation.logits_process import (
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
17
18
19
20
21
22
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    TopKLogitsWarper,
)

23
24
from text_generation.pb import generate_pb2

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
25
26

class Sampling:
    """Choose one token per row by sampling from the softmax of the logits.

    A dedicated ``torch.Generator`` seeded with ``seed`` drives the draws, so
    results are reproducible for a given seed and device.
    """

    def __init__(self, seed: int, device: str = "cpu"):
        rng = torch.Generator(device)
        rng.manual_seed(seed)
        self.generator = rng
        self.seed = seed

    def __call__(self, logits):
        # assumes logits is 2-D (batch, vocab) — squeeze(1) removes the
        # num_samples=1 axis that torch.multinomial appends; TODO confirm callers
        distribution = logits.softmax(dim=-1)
        sampled = torch.multinomial(
            distribution, num_samples=1, generator=self.generator
        )
        return sampled.squeeze(1)


class Greedy:
    """Deterministically choose the highest-scoring entry along the last dimension."""

    def __call__(self, logits):
        return torch.argmax(logits, dim=-1)


class NextTokenChooser:
    """Warp next-token scores, then pick the next token (sampled or greedy).

    Every non-neutral ``temperature``, ``top_k`` or ``top_p`` value adds the
    matching logits warper and switches to sampling, even when ``do_sample``
    is False. Without any of them (and with ``do_sample=False``) the chooser
    is greedy.
    """

    def __init__(
        self,
        temperature=1.0,
        top_k=None,
        top_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        sampling = do_sample
        # 1.0 leaves the scores unchanged, so no warper is needed for it.
        if temperature is not None and temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(float(temperature)))
            sampling = True
        # top_k == 0 is treated as "disabled".
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
            sampling = True
        # top_p >= 1.0 is treated as "disabled".
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        """Return ``(next_ids, logprobs)`` for one decoding step.

        ``logprobs`` is the log-softmax over the last dimension of the *warped*
        scores, and ``next_ids`` is chosen from those same warped scores.
        """
        # Warp logits
        scores = self.warpers(input_ids, scores)
        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)
        # Choose tokens
        next_ids = self.choice(scores)
        return next_ids, logprobs

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
    ) -> "NextTokenChooser":
        """Build a chooser from protobuf parameters.

        ``device`` is passed on as a string because it ends up in
        ``torch.Generator`` when sampling is enabled.
        """
        # Use cls (not the hard-coded class name) so subclasses get their own type.
        return cls(
            temperature=pb.temperature,
            top_k=pb.top_k,
            top_p=pb.top_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=str(device),
        )


class StopSequenceCriteria:
    """Detect when generated text ends with a literal stop sequence."""

    def __init__(self, stop_sequence: str):
        # Stop sequences are literal text, not patterns: escape them so that
        # regex metacharacters ("?", "(", ".", ...) match themselves instead of
        # raising re.error or matching unintended text. No leading ".*" is
        # needed: an unanchored search for "X$" succeeds exactly when ".*X$"
        # would, without backtracking over the whole output.
        self.regex = re.compile(f"{re.escape(stop_sequence)}$")

    def __call__(self, output: str) -> bool:
        """Return True if ``output`` ends with the stop sequence.

        As with any ``$`` anchor, a single trailing newline after the stop
        sequence is still accepted.
        """
        return self.regex.search(output) is not None

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
105
106

class StoppingCriteria:
    """Decide after each generated token whether generation should stop.

    Checks run in this order and the first hit wins: token budget
    (``"length"``), end-of-sequence token (``"eos_token"``), then stop
    sequences (``"stop_sequence"``).
    """

    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens: int = 20,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        # Number of tokens passed to __call__ so far.
        self.current_tokens = 0
        # Text accumulated from ``last_output`` values, matched against stop sequences.
        self.current_output = ""

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        """Register one generated token and report ``(should_stop, reason)``.

        ``reason`` is ``"length"``, ``"eos_token"``, ``"stop_sequence"``, or
        ``None`` when generation should continue.
        """
        self.current_tokens += 1
        # The budget is checked first, so a final allowed token that is also
        # the eos token is reported as "length".
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"

        if last_token == self.eos_token_id:
            return True, "eos_token"

        # Only reached when neither check above stopped generation.
        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
                return True, "stop_sequence"

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        """Build stopping criteria from protobuf parameters.

        Each entry of ``pb.stop_sequences`` becomes a StopSequenceCriteria; the
        eos id comes from ``tokenizer.eos_token_id`` and the budget from
        ``pb.max_new_tokens``.
        """
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        # Use cls (not the hard-coded class name) so subclasses get their own type.
        return cls(
            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
        )
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166


def initialize_torch_distributed():
    """Initialise ``torch.distributed`` from the RANK / WORLD_SIZE env vars.

    RANK defaults to 0 and WORLD_SIZE to 1. When CUDA is available the NCCL
    backend is used and this process is pinned to GPU ``rank % device_count``;
    otherwise the gloo backend is used.

    Returns:
        ``(default_process_group, rank, world_size)``.
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    backend = "gloo"
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        assert world_size <= gpu_count, "Each process is one gpu"
        # Pin this process to its own GPU before creating the NCCL group.
        torch.cuda.set_device(rank % gpu_count)
        backend = "nccl"

    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
    )

    # NOTE: _get_default_group is a private torch API.
    default_group = torch.distributed.distributed_c10d._get_default_group()
    return default_group, rank, world_size


173
def weight_hub_files(model_name, extension=".safetensors"):
    """List the files of Hub repo ``model_name`` whose names end with ``extension``.

    Queries the Hub API for the repository metadata (network access).
    """
    siblings = HfApi().model_info(model_name).siblings
    return [
        sibling.rfilename
        for sibling in siblings
        if sibling.rfilename.endswith(extension)
    ]


181
def weight_files(model_name, extension=".safetensors"):
    """Resolve local cache paths for the weight files of ``model_name``.

    The expected filenames are listed from the Hub via ``weight_hub_files``
    (network access). Each one must already be in the local cache.

    Returns:
        Local file paths, in the order the Hub listed the files.

    Raises:
        LocalEntryNotFoundError: if any expected file is not cached locally.
    """
    filenames = weight_hub_files(model_name, extension)

    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(model_name, filename=filename)
        # A cache hit is a path string. Misses are reported as None or, in
        # some huggingface_hub versions, as a non-str sentinel
        # (_CACHED_NO_EXIST), so accept only actual paths.
        if not isinstance(cache_file, str):
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_name} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_name}` first."
            )
        files.append(cache_file)

    return files


198
def download_weights(model_name, extension=".safetensors"):
    """Download the ``extension`` weight files of ``model_name`` from the Hub.

    Up to 5 files are downloaded concurrently, with a tqdm progress bar.

    Returns:
        Local paths returned by ``hf_hub_download``, in the same order as the
        Hub file listing.

    Raises:
        Whatever ``hf_hub_download`` raises for the first failed download
        observed.
    """
    filenames = weight_hub_files(model_name, extension)

    download_function = partial(
        hf_hub_download,
        repo_id=model_name,
        local_files_only=False,
    )

    # The context manager shuts the pool down (joining its worker threads)
    # once the downloads finish, instead of leaving the executor alive.
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(download_function, filename=filename)
            for filename in filenames
        ]
        # Advance the progress bar in completion order; result() re-raises a
        # download error as soon as it is observed.
        for future in tqdm(
            concurrent.futures.as_completed(futures), total=len(futures)
        ):
            future.result()

    # Collect in submission order so the returned list is deterministic.
    return [future.result() for future in futures]