utils.py

import concurrent.futures
import os
import re

import torch
import torch.distributed

from datetime import timedelta

from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
from transformers.generation.logits_process import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    TopKLogitsWarper,
)

from text_generation.pb import generate_pb2


class Sampling:
    """Sample the next token ids from a seeded generator for reproducible results."""

    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_tokens = torch.multinomial(
            probs, num_samples=1, generator=self.generator
        ).squeeze(1)
        return next_tokens


class Greedy:
    def __call__(self, logits):
        return logits.argmax(dim=-1)


class NextTokenChooser:
    """Warp the logits (temperature, top-k, top-p) and choose the next token,
    greedily by default or by sampling when any sampling parameter is set."""

    def __init__(
        self,
        temperature=1.0,
        top_k=None,
        top_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        sampling = do_sample
        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            warpers.append(TemperatureLogitsWarper(temperature))
            sampling = True
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
            sampling = True
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        # Warp logits
        scores = self.warpers(input_ids, scores)
        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)
        # Choose tokens
        next_ids = self.choice(scores)
        return next_ids, logprobs

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
    ) -> "NextTokenChooser":
        return NextTokenChooser(
            temperature=pb.temperature,
            top_k=pb.top_k,
            top_p=pb.top_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=str(device),
        )


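# A minimal usage sketch (illustrative, not part of the original file):
# `logits` is assumed to be a [batch, vocab_size] float tensor. Setting
# top_p < 1.0 switches the chooser from greedy decoding to seeded sampling.
#
#   chooser = NextTokenChooser(temperature=0.8, top_p=0.95, seed=42)
#   next_ids, logprobs = chooser(input_ids, logits)

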
class StopSequenceCriteria:
    """Check whether the generated output ends with a given stop sequence."""

    def __init__(self, stop_sequence: str):
        # Escape the sequence so regex metacharacters are matched literally
        self.regex = re.compile(f".*{re.escape(stop_sequence)}$")

    def __call__(self, output: str) -> bool:
        if self.regex.findall(output):
            return True
        return False


class StoppingCriteria:
    """Decide whether generation should stop: token budget, EOS, or stop sequence."""

    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens=20,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_output = ""

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"

        if last_token == self.eos_token_id:
            return True, "eos_token"

        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
                return True, "stop_sequence"

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        return StoppingCriteria(
            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
        )

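
# A usage sketch (illustrative, not part of the original file): the criteria
# is called once per generated token with the token id and its decoded text,
# and returns `(should_stop, finish_reason)`.
#
#   stopping = StoppingCriteria(
#       eos_token_id=2,
#       stop_sequence_criterias=[StopSequenceCriteria("###")],
#       max_new_tokens=8,
#   )
#   stop, reason = stopping(last_token=17, last_output=" world")
#   # -> (False, None) until "length", "eos_token" or "stop_sequence" triggers
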

def initialize_torch_distributed():
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    if torch.cuda.is_available():
        # Initialize `torch.distributed` with the NCCL backend
        # and assign each process its own GPU
        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
        device = rank % torch.cuda.device_count()
        torch.cuda.set_device(device)
        backend = "nccl"
    else:
        backend = "gloo"

    # Call the init process.
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
    )

    return torch.distributed.distributed_c10d._get_default_group(), rank, world_size

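
# A usage sketch (illustrative, not part of the original file): RANK and
# WORLD_SIZE are read from the environment, and `init_process_group` with no
# explicit init_method also expects MASTER_ADDR / MASTER_PORT to be set, e.g.:
#
#   MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 RANK=0 WORLD_SIZE=2 python server.py
#   MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 RANK=1 WORLD_SIZE=2 python server.py
#
#   process_group, rank, world_size = initialize_torch_distributed()
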

def weight_hub_files(model_name, revision=None, extension=".safetensors"):
    """Get the safetensors filenames on the hub"""
    api = HfApi()
    info = api.model_info(model_name, revision=revision)
    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
    return filenames


def try_to_load_from_cache(model_name, revision, filename):
    """Try to load a file from the Hugging Face cache"""
    if revision is None:
        revision = "main"

    object_id = model_name.replace("/", "--")
    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"

    if not repo_cache.is_dir():
        # No cache for this model
        return None

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"
    no_exist_dir = repo_cache / ".no_exist"

    # Resolve refs (for instance to convert main to the associated commit sha)
    if refs_dir.is_dir():
        revision_file = refs_dir / revision
        if revision_file.exists():
            with revision_file.open() as f:
                revision = f.read()

    # Check if file is cached as "no_exist"
    if (no_exist_dir / revision / filename).is_file():
        return _CACHED_NO_EXIST

    # Check if revision folder exists
    if not snapshots_dir.exists():
        return None
    cached_shas = os.listdir(snapshots_dir)
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        return None

    # Check if file exists in cache
    cached_file = snapshots_dir / revision / filename
    return str(cached_file) if cached_file.is_file() else None
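

# For reference (illustrative paths): with model_name "bigscience/bloom" the
# lookup above resolves, under the default cache directory, e.g.
#   ~/.cache/huggingface/hub/models--bigscience--bloom/refs/main        (commit sha)
#   ~/.cache/huggingface/hub/models--bigscience--bloom/snapshots/<sha>/<filename>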


def weight_files(model_name, revision=None, extension=".safetensors"):
    """Get the local safetensors filenames"""
    filenames = weight_hub_files(model_name, revision, extension)
    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(
            model_name, revision=revision, filename=filename
        )
        if cache_file is None:
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_name} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_name}` first."
            )
        files.append(cache_file)

    return files


def download_weights(model_name, revision=None, extension=".safetensors"):
    """Download the safetensors files from the hub"""
    filenames = weight_hub_files(model_name, revision, extension)

    download_function = partial(
        hf_hub_download,
        repo_id=model_name,
        local_files_only=False,
    )

    # Download the files in parallel with a small thread pool
    executor = ThreadPoolExecutor(max_workers=5)
    futures = [
        executor.submit(download_function, filename=filename, revision=revision)
        for filename in filenames
    ]
    files = [
        future.result()
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
    ]

    return files
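

# Typical flow (illustrative, not part of the original file): download the
# weights once, then resolve their local paths from the cache.
#
#   download_weights("bigscience/bloom-560m")
#   local_files = weight_files("bigscience/bloom-560m")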