utils.py 8.66 KB
Newer Older
1
import concurrent
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
import os
3
import re
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
4
5
import torch
import torch.distributed
Olivier Dehaene's avatar
Olivier Dehaene committed
6
7

from datetime import timedelta
Nicolas Patry's avatar
Nicolas Patry committed
8

9
from concurrent.futures import ThreadPoolExecutor
Nicolas Patry's avatar
Nicolas Patry committed
10
from functools import partial
11
12
13
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
Nicolas Patry's avatar
Nicolas Patry committed
14
15
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
16
from typing import List, Optional, Tuple
17
from transformers import PreTrainedTokenizerBase
18
from transformers.generation.logits_process import (
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
19
    LogitsProcessorList,
20
    RepetitionPenaltyLogitsProcessor,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
21
22
23
24
25
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    TopKLogitsWarper,
)

26
27
from text_generation.pb import generate_pb2

28
# Optional directory of weight files. When the environment variable is set,
# `weight_files` / `download_weights` glob this directory instead of using the hub.
WEIGHTS_CACHE_OVERRIDE = os.environ.get("WEIGHTS_CACHE_OVERRIDE")

class Sampling:
    """Choose token ids by multinomial sampling from ``softmax(logits)``.

    Each instance owns a ``torch.Generator`` seeded with ``seed``. Draws are
    therefore reproducible for a given seed and are taken from that generator
    rather than the global RNG.
    """

    def __init__(self, seed: int, device: str = "cpu"):
        self.seed = seed
        rng = torch.Generator(device)
        rng.manual_seed(seed)
        self.generator = rng

    def __call__(self, logits):
        # Normalise scores into probabilities along the last axis.
        distribution = torch.nn.functional.softmax(logits, dim=-1)
        # One draw per row; assumes `logits` is 2-D (batch, vocab), so the
        # (batch, 1) result is flattened to (batch,) by dropping axis 1.
        drawn = torch.multinomial(distribution, num_samples=1, generator=self.generator)
        return drawn.squeeze(1)


class Greedy:
    """Choose the highest-scoring token id along the last axis of ``logits``."""

    def __call__(self, logits):
        return torch.argmax(logits, dim=-1)


class NextTokenChooser:
    """Turn one decoding step's logits into next token ids and log-probabilities.

    A ``LogitsProcessorList`` is built from the requested generation parameters.
    Token selection is multinomial ``Sampling`` when ``do_sample`` is set or any
    of temperature / top-k / top-p is active, and ``Greedy`` otherwise.
    """

    def __init__(
        self,
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        """
        Args:
            temperature: adds a ``TemperatureLogitsWarper`` (and enables
                sampling) unless None or 1.0.
            repetition_penalty: adds a ``RepetitionPenaltyLogitsProcessor``
                unless None or 1.0; does not by itself enable sampling.
            top_k: adds a ``TopKLogitsWarper`` (and enables sampling) unless
                None or 0.
            top_p: adds a ``TopPLogitsWarper`` (and enables sampling) unless
                None or >= 1.0.
            do_sample: force sampling even when no warper is configured.
            seed: seed for the sampling generator.
            device: device on which the sampling generator is created.
        """
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        sampling = do_sample

        # The repetition penalty must act on the raw logits, before top-k/top-p
        # truncation. Otherwise truncation is decided on unpenalized scores.
        # transformers uses the same order: logits processors run before logits
        # warpers during generation.
        if repetition_penalty is not None and repetition_penalty != 1.0:
            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))

        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            warpers.append(TemperatureLogitsWarper(temperature))
            sampling = True
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
            sampling = True
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        """Apply the configured processors, then pick the next tokens.

        Returns:
            ``(next_ids, logprobs)``. ``logprobs`` is the log-softmax over the
            last axis of the processed scores.
        """
        # Warp logits
        scores = self.warpers(input_ids, scores)

        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)

        # Choose tokens
        next_ids = self.choice(scores)
        return next_ids, logprobs

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
    ) -> "NextTokenChooser":
        """Build a chooser from protobuf request parameters."""
        # `cls` (not the hard-coded class) so subclasses get their own type back.
        return cls(
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
        )


class StopSequenceCriteria:
    """Detect whether generated text ends with a given stop sequence."""

    def __init__(self, stop_sequence: str):
        # Escape the user-supplied sequence so regex metacharacters match
        # literally. Unescaped, "?" turned the old ".*" prefix lazy and matched
        # every output, and "(" failed to compile. No ".*" prefix is needed
        # because `search` already scans the whole string.
        self.regex = re.compile(f"{re.escape(stop_sequence)}$")

    def __call__(self, output: str) -> bool:
        """Return True if ``output`` ends with the stop sequence.

        The pattern is anchored with ``$``, so a stop sequence followed by a
        single final newline also counts as a match.
        """
        return self.regex.search(output) is not None

class StoppingCriteria:
    """Decide, one generated token at a time, whether a request should stop."""

    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens=20,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        # Number of tokens passed to __call__ so far.
        self.current_tokens = 0
        # Concatenated decoded text seen so far (only appended once the length
        # and EOS checks have passed for a step).
        self.current_output = ""

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        """Record one generated token and report whether to stop.

        Args:
            last_token: id of the token just generated.
            last_output: decoded text for that token.

        Returns:
            ``(True, "length")`` once ``max_new_tokens`` tokens have been seen.
            This check runs first, so it wins over EOS on the same step.
            Otherwise ``(True, "eos_token")`` if ``last_token`` is the EOS id,
            ``(True, "stop_sequence")`` if the accumulated output matches any
            stop sequence, and ``(False, None)`` in all other cases.
        """
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"

        if last_token == self.eos_token_id:
            return True, "eos_token"

        self.current_output += last_output
        if any(criteria(self.current_output) for criteria in self.stop_sequence_criterias):
            return True, "stop_sequence"

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        """Build criteria from protobuf parameters and the tokenizer's EOS id."""
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        # `cls` (not the hard-coded class) so subclasses get their own type back.
        return cls(
            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
        )


def initialize_torch_distributed():
    """Initialise the default ``torch.distributed`` process group.

    Reads ``RANK`` (default 0) and ``WORLD_SIZE`` (default 1) from the
    environment. With CUDA available, the process is bound to GPU
    ``rank % device_count`` and the NCCL backend is used; otherwise Gloo is
    used. A 60 second ``timeout`` is passed to ``init_process_group``.

    Returns:
        ``(process_group, rank, world_size)``, where ``process_group`` is the
        default group after initialisation.

    Raises:
        RuntimeError: if CUDA is available but fewer GPUs are visible than
            ``WORLD_SIZE`` (one GPU per process is required).
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if world_size > device_count:
            raise RuntimeError(
                f"Each process is one gpu: WORLD_SIZE={world_size} but only "
                f"{device_count} CUDA device(s) are visible."
            )
        # Set the device id.
        torch.cuda.set_device(rank % device_count)
        backend = "nccl"
    else:
        backend = "gloo"

    # Call the init process.
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
    )

    return torch.distributed.distributed_c10d._get_default_group(), rank, world_size


183
def weight_hub_files(model_name, revision=None, extension=".safetensors"):
    """List hub filenames of ``model_name`` at ``revision`` that end with ``extension``.

    Queries the hub through ``HfApi.model_info``; nothing is downloaded.
    """
    repo_info = HfApi().model_info(model_name, revision=revision)
    repo_files = (sibling.rfilename for sibling in repo_info.siblings)
    return [name for name in repo_files if name.endswith(extension)]


191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
def try_to_load_from_cache(model_name, revision, filename):
    """Look up ``filename`` of ``model_name`` at ``revision`` in the local HF cache.

    Only the on-disk cache layout under ``HUGGINGFACE_HUB_CACHE`` is inspected;
    the network is never contacted. ``revision`` defaults to ``"main"`` and is
    resolved through ``refs/`` when a matching ref file exists.

    Returns:
        The cached file path as a ``str`` if present.
        ``_CACHED_NO_EXIST`` if the cache recorded that the file does not exist
        at that revision.
        ``None`` otherwise.
    """
    if revision is None:
        revision = "main"

    object_id = model_name.replace("/", "--")
    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"

    if not repo_cache.is_dir():
        # No cache for this model
        return None

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"
    no_exist_dir = repo_cache / ".no_exist"

    # Resolve refs (for instance to convert main to the associated commit sha).
    # `is_file` avoids crashing when the ref path is a directory. Stripping
    # tolerates a ref file that ends with a newline.
    revision_file = refs_dir / revision
    if revision_file.is_file():
        revision = revision_file.read_text(encoding="utf-8").strip()

    # Check if file is cached as "no_exist"
    if (no_exist_dir / revision / filename).is_file():
        return _CACHED_NO_EXIST

    # Only the snapshot of this exact revision is used; never fall back to a
    # different cached revision.
    snapshot_dir = snapshots_dir / revision
    if not snapshot_dir.is_dir():
        return None

    cached_file = snapshot_dir / filename
    return str(cached_file) if cached_file.is_file() else None


def weight_files(model_name, revision=None, extension=".safetensors"):
    """Return local paths of the ``extension`` weight files of ``model_name``.

    When ``WEIGHTS_CACHE_OVERRIDE`` is set, the ``*{extension}`` files in that
    directory are returned as ``Path`` objects. Otherwise the file list comes
    from the hub, and every file must already be in the local Hugging Face
    cache; those paths are returned in hub listing order.

    Raises:
        LocalEntryNotFoundError: if any listed file is not in the local cache.
    """
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

    filenames = weight_hub_files(model_name, revision, extension)

    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(
            model_name, revision=revision, filename=filename
        )
        # A miss is reported either as None or as the `_CACHED_NO_EXIST`
        # sentinel. The sentinel must never be returned as if it were a path.
        if cache_file is None or cache_file is _CACHED_NO_EXIST:
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_name} not found in "
                f"{HUGGINGFACE_HUB_CACHE}. "
                f"Please run `text-generation-server download-weights {model_name}` first."
            )
        files.append(cache_file)

    return files


253
def download_weights(model_name, revision=None, extension=".safetensors"):
    """Download the ``extension`` weight files of ``model_name`` from the hub.

    Files are fetched with ``hf_hub_download`` on up to 5 worker threads, and a
    tqdm progress bar advances as each download completes. When
    ``WEIGHTS_CACHE_OVERRIDE`` is set, nothing is downloaded; the
    ``*{extension}`` files in that directory are returned as ``Path`` objects.

    Returns:
        The paths returned by ``hf_hub_download``, in the same order as the
        filenames listed by ``weight_hub_files``.

    Raises:
        Whatever ``hf_hub_download`` raises for a failed file.
    """
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

    filenames = weight_hub_files(model_name, revision, extension)

    download_function = partial(
        hf_hub_download,
        repo_id=model_name,
        local_files_only=False,
    )

    # The context manager shuts the pool down and joins its worker threads
    # when the block exits, instead of leaking them.
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(download_function, filename=filename, revision=revision)
            for filename in filenames
        ]
        # Advance the progress bar in completion order. `result()` re-raises a
        # download error as soon as it is observed.
        for future in tqdm(
            concurrent.futures.as_completed(futures), total=len(futures)
        ):
            future.result()

    # Report paths in listing order, not in nondeterministic completion order.
    return [future.result() for future in futures]