utils.py 8.76 KB
Newer Older
1
import concurrent
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
import os
3
import re
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
4
5
import torch
import torch.distributed
Olivier Dehaene's avatar
Olivier Dehaene committed
6
7

from datetime import timedelta
Nicolas Patry's avatar
Nicolas Patry committed
8

9
from concurrent.futures import ThreadPoolExecutor
Nicolas Patry's avatar
Nicolas Patry committed
10
from functools import partial
11
12
13
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
Nicolas Patry's avatar
Nicolas Patry committed
14
15
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
16
from typing import List, Optional, Tuple
17
from transformers import PreTrainedTokenizerBase
18
from transformers.generation.logits_process import (
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
19
    LogitsProcessorList,
20
    RepetitionPenaltyLogitsProcessor,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
21
22
23
24
25
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    TopKLogitsWarper,
)

26
from text_generation.pb import generate_pb2
27
from text_generation.pb.generate_pb2 import FinishReason
28

29
# Optional directory read from the `WEIGHTS_CACHE_OVERRIDE` environment variable
# (None when unset). When it is set, `weight_files` and `download_weights` return
# the files matching `*{extension}` in this directory instead of using the hub.
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
30

31

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
32
class Sampling:
    """Choose next tokens by sampling from the softmax of the logits.

    The sampler owns a dedicated ``torch.Generator`` seeded with ``seed`` so
    that two samplers built with the same seed draw the same tokens.
    """

    def __init__(self, seed: int, device: str = "cpu"):
        """Create the sampler.

        Args:
            seed: Seed given to the private random generator.
            device: Device on which the generator is created.
        """
        # Private generator: draws do not consume torch's global RNG state.
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        """Draw one index per row of ``logits``.

        Args:
            logits: 2-D tensor of unnormalized scores; the softmax is taken
                over the last dimension.

        Returns:
            A 1-D tensor holding one sampled index for each row.
        """
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # `multinomial` returns shape (rows, 1); drop the sample dimension.
        next_tokens = torch.multinomial(
            probs, num_samples=1, generator=self.generator
        ).squeeze(1)
        return next_tokens


class Greedy:
    """Deterministic chooser: always keeps the highest-scoring index."""

    def __call__(self, logits):
        """Return the argmax of ``logits`` along its last dimension."""
        best = torch.argmax(logits, dim=-1)
        return best


class NextTokenChooser:
    """Warp the model scores and pick the next token ids.

    The chooser builds a list of logits processors from the generation
    parameters and then selects tokens either greedily or by seeded sampling.
    """

    def __init__(
        self,
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        """Build the processors and the token choice strategy.

        Args:
            temperature: Adds a temperature warper when not None and != 1.0.
            repetition_penalty: Adds a repetition penalty processor when not
                None and != 1.0. It does not, by itself, enable sampling.
            top_k: Adds a top-k warper when not None and != 0.
            top_p: Adds a top-p warper when not None and < 1.0.
            do_sample: Force sampling even when no warper is configured.
            seed: Seed of the sampler (only used when sampling).
            device: Device of the sampler's random generator.
        """
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        # Any active temperature / top-k / top-p setting implies sampling.
        sampling = do_sample
        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            warpers.append(TemperatureLogitsWarper(temperature))
            sampling = True
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
            sampling = True
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True
        if repetition_penalty is not None and repetition_penalty != 1.0:
            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        """Choose the next token ids.

        Args:
            input_ids: Token ids passed to the logits processors.
            scores: Scores to warp and choose from.

        Returns:
            A ``(next_ids, logprobs)`` tuple where ``logprobs`` is the
            log-softmax of the warped scores over the last dimension.
        """
        # Warp logits
        scores = self.warpers(input_ids, scores)

        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)

        # Choose tokens
        next_ids = self.choice(scores)
        return next_ids, logprobs

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
    ) -> "NextTokenChooser":
        """Build a chooser from its protobuf parameters and a target device."""
        # Use `cls` so that subclasses get an instance of their own type.
        return cls(
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
        )


class StopSequenceCriteria:
    """Tell whether a generated text ends with a given stop sequence."""

    def __init__(self, stop_sequence: str):
        """Compile the matcher for ``stop_sequence``.

        Args:
            stop_sequence: Literal text that must terminate the output.
        """
        # The stop sequence is user text, not a pattern: escape it so regex
        # metacharacters such as "(", ".", "?" or "*" are matched literally
        # instead of raising `re.error` or matching unrelated text.
        self.regex = re.compile(f".*{re.escape(stop_sequence)}$")

    def __call__(self, output: str) -> bool:
        """Return True when ``output`` ends with the stop sequence."""
        return bool(self.regex.findall(output))

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
117
118

class StoppingCriteria:
    """Decide, token after token, whether a generation must stop.

    Generation stops when the token budget is reached, when the end-of-sequence
    token is produced, or when the accumulated output matches one of the stop
    sequence criteria.
    """

    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens=20,
    ):
        """Store the stopping parameters and reset the running state.

        Args:
            eos_token_id: Id of the end-of-sequence token.
            stop_sequence_criterias: Criteria evaluated on the accumulated
                output text.
            max_new_tokens: Maximum number of tokens to generate.
        """
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        # Running state, updated on every call.
        self.current_tokens = 0
        self.current_output = ""

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        """Register a newly generated token and report whether to stop.

        Args:
            last_token: Id of the token that was just generated.
            last_output: Text of that token, appended to the accumulated output.

        Returns:
            ``(True, reason)`` with a ``FinishReason`` member when generation
            must stop, ``(False, None)`` otherwise.
            NOTE(review): the annotation says ``Optional[str]`` but the value
            returned is a ``FinishReason`` member - confirm the intended type.
        """
        self.current_tokens += 1
        # The length check comes first: a token that both exhausts the budget
        # and is EOS is reported as a length stop.
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        if last_token == self.eos_token_id:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        # Only reached when neither of the checks above stopped generation.
        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        """Build the criteria from protobuf parameters and a tokenizer.

        The end-of-sequence id is read from ``tokenizer.eos_token_id`` and one
        ``StopSequenceCriteria`` is created per entry of ``pb.stop_sequences``.
        """
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        # Use `cls` so that subclasses get an instance of their own type.
        return cls(
            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
        )
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178


def initialize_torch_distributed():
    """Initialize the default `torch.distributed` process group.

    The rank and world size are read from the `RANK` and `WORLD_SIZE`
    environment variables (defaulting to 0 and 1). When CUDA is available the
    current CUDA device is set from the rank and the "nccl" backend is used;
    otherwise the "gloo" backend is used.

    Returns:
        A ``(process_group, rank, world_size)`` tuple, where ``process_group``
        is the default group created by the initialization.
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    if torch.cuda.is_available():
        # Bind this process to a CUDA device derived from its rank.
        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
        device = rank % torch.cuda.device_count()
        torch.cuda.set_device(device)
        backend = "nccl"
    else:
        backend = "gloo"

    # Call the init process.
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
    )

    # NOTE(review): relies on a private torch helper to fetch the default group.
    return torch.distributed.distributed_c10d._get_default_group(), rank, world_size


185
def weight_hub_files(model_id, revision=None, extension=".safetensors"):
    """Get the names of the weight files of a model on the hub.

    Args:
        model_id: Identifier of the model repository.
        revision: Optional revision of the repository to inspect.
        extension: Suffix a filename must end with to be returned
            (safetensors files by default).

    Returns:
        The list of repository filenames ending with ``extension``.
    """
    api = HfApi()
    info = api.model_info(model_id, revision=revision)
    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
    return filenames


193
def try_to_load_from_cache(model_id, revision, filename):
    """Try to load a file from the Hugging Face cache.

    Args:
        model_id: Identifier of the model repository.
        revision: Revision to look up; ``None`` means "main".
        filename: Name of the file inside the repository.

    Returns:
        - the path of the cached file, as a string, when it is cached;
        - ``_CACHED_NO_EXIST`` when the cache records that the file does not
          exist for this revision;
        - ``None`` when nothing is known about the file locally.
    """
    if revision is None:
        revision = "main"

    object_id = model_id.replace("/", "--")
    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"

    if not repo_cache.is_dir():
        # No cache for this model
        return None

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"
    no_exist_dir = repo_cache / ".no_exist"

    # Resolve refs (for instance to convert main to the associated commit sha)
    if refs_dir.is_dir():
        revision_file = refs_dir / revision
        if revision_file.exists():
            with revision_file.open() as f:
                revision = f.read()

    # Check if file is cached as "no_exist"
    if (no_exist_dir / revision / filename).is_file():
        return _CACHED_NO_EXIST

    # Check if revision folder exists
    if not snapshots_dir.exists():
        return None
    cached_shas = os.listdir(snapshots_dir)
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        return None

    # Check if file exists in cache
    cached_file = snapshots_dir / revision / filename
    return str(cached_file) if cached_file.is_file() else None


233
def weight_files(model_id, revision=None, extension=".safetensors"):
    """Get the local weight files of a model.

    When `WEIGHTS_CACHE_OVERRIDE` is set, the files matching ``*{extension}``
    in that directory are returned. Otherwise the filenames are listed from the
    hub and each one is resolved in the local Hugging Face cache.

    Args:
        model_id: Identifier of the model repository.
        revision: Optional revision of the repository.
        extension: Suffix of the weight files (safetensors by default).

    Returns:
        The list of local weight files.

    Raises:
        LocalEntryNotFoundError: If one of the files is not available in the
            local cache.
    """
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

    filenames = weight_hub_files(model_id, revision, extension)
    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(
            model_id, revision=revision, filename=filename
        )
        # `_CACHED_NO_EXIST` is a sentinel, not a path: treat it like a missing
        # file instead of returning it among the weight files.
        if cache_file is None or cache_file is _CACHED_NO_EXIST:
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_id} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_id}` first."
            )
        files.append(cache_file)

    return files


255
def download_weights(model_id, revision=None, extension=".safetensors"):
    """Download the weight files of a model from the hub.

    When `WEIGHTS_CACHE_OVERRIDE` is set, nothing is downloaded and the files
    matching ``*{extension}`` in that directory are returned instead.

    Args:
        model_id: Identifier of the model repository.
        revision: Optional revision of the repository.
        extension: Suffix of the weight files (safetensors by default).

    Returns:
        The list of values returned by `hf_hub_download`, in completion order.
    """
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

    filenames = weight_hub_files(model_id, revision, extension)

    download_function = partial(
        hf_hub_download,
        repo_id=model_id,
        local_files_only=False,
    )

    # The context manager shuts the pool down once the downloads are collected,
    # so worker threads are not left behind.
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(download_function, filename=filename, revision=revision)
            for filename in filenames
        ]
        # `result()` re-raises any exception raised by a download.
        files = [
            future.result()
            for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
        ]

    return files