import concurrent.futures
import os
import re

import torch
import torch.distributed

from datetime import timedelta
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    TopKLogitsWarper,
)

from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason

WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)


class Sampling:
    """Sample a token id from the softmax of the logits with a seeded generator."""

    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_tokens = torch.multinomial(
            probs, num_samples=1, generator=self.generator
        ).squeeze(1)
        return next_tokens


class Greedy:
    """Pick the highest-scoring token id (argmax of the logits)."""

    def __call__(self, logits):
        return logits.argmax(dim=-1)

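# Illustrative sketch (not part of the server code, values made up): with a
# fixed seed, `Sampling` is reproducible across runs, while `Greedy` is
# deterministic by construction.
#
#     logits = torch.tensor([[1.0, 2.0, 3.0]])
#     Sampling(seed=42)(logits)  # e.g. tensor([2]); same result on every run
#     Greedy()(logits)           # tensor([2]); always the argmax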

class NextTokenChooser:
    """Warp the logits (temperature, top-k, top-p, repetition penalty) and choose the next token."""

    def __init__(
        self,
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all warpers can be found in `transformers.generation.logits_process`
        sampling = do_sample
        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            warpers.append(TemperatureLogitsWarper(temperature))
            sampling = True
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
            sampling = True
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True
        if repetition_penalty is not None and repetition_penalty != 1.0:
            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        # Warp logits
        scores = self.warpers(input_ids, scores)

        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)

        # Choose tokens
        next_ids = self.choice(scores)
        return next_ids, logprobs

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
    ) -> "NextTokenChooser":
        return NextTokenChooser(
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
        )

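# Usage sketch (shapes assumed, values made up): `scores` holds the logits of
# the last decoding step, shape (batch, vocab_size); `input_ids` is only used
# by the repetition penalty processor.
#
#     chooser = NextTokenChooser(temperature=0.7, top_k=50, seed=42)
#     next_ids, logprobs = chooser(input_ids, scores)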

class StopSequenceCriteria:
    """Check whether a generated output ends with a given stop sequence."""

    def __init__(self, stop_sequence: str):
        # Escape the stop sequence so regex metacharacters in it are matched literally
        self.regex = re.compile(f".*{re.escape(stop_sequence)}$")

    def __call__(self, output: str) -> bool:
        if self.regex.findall(output):
            return True
        return False

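# Example (illustrative): the criterion only fires when the accumulated output
# *ends* with the stop sequence, thanks to the `$` anchor.
#
#     criteria = StopSequenceCriteria("\n\n")
#     criteria("Hello")       # False
#     criteria("Hello\n\n")   # True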

class StoppingCriteria:
    """Decide when generation should stop: max tokens, EOS token, or a stop sequence."""

    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens=20,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_output = ""

    def __call__(
        self, last_token: int, last_output: str
    ) -> Tuple[bool, Optional["FinishReason"]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        if last_token == self.eos_token_id:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        return StoppingCriteria(
            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
        )

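# Usage sketch (token ids and EOS id are made up): the criteria object is
# stateful and must be called once per generated token.
#
#     criteria = StoppingCriteria(eos_token_id=0, stop_sequence_criterias=[], max_new_tokens=2)
#     criteria(5, " world")   # (False, None)
#     criteria(6, "!")        # (True, FinishReason.FINISH_REASON_LENGTH)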

def initialize_torch_distributed():
    """Initialize the default process group from the RANK and WORLD_SIZE environment variables."""
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

        # Set the device id.
        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
        device = rank % torch.cuda.device_count()
        torch.cuda.set_device(device)
        backend = "nccl"
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
        # `_timeout` is a private field; it sets the NCCL operation timeout
        options._timeout = timedelta(seconds=60)
    else:
        backend = "gloo"
        options = None

    # Call the init process.
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
        pg_options=options,
    )

    return torch.distributed.group.WORLD, rank, world_size

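# Launch sketch: the RANK and WORLD_SIZE variables read above are the ones set
# by standard launchers, e.g. (script name is illustrative):
#
#     torchrun --nproc_per_node=2 server.py
#
# where each process then calls:
#
#     process_group, rank, world_size = initialize_torch_distributed()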

def weight_hub_files(model_id, revision=None, extension=".safetensors"):
    """Get the safetensors filenames on the hub"""
    api = HfApi()
    info = api.model_info(model_id, revision=revision)
    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
    return filenames

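# Example (model id and filename are illustrative):
#
#     weight_hub_files("bigscience/bloom-560m")
#     # -> ['model.safetensors']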

def try_to_load_from_cache(model_id, revision, filename):
    """Try to load a file from the Hugging Face cache"""
    if revision is None:
        revision = "main"

    object_id = model_id.replace("/", "--")
    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"

    if not repo_cache.is_dir():
        # No cache for this model
        return None

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"
    no_exist_dir = repo_cache / ".no_exist"

    # Resolve refs (for instance to convert main to the associated commit sha)
    if refs_dir.is_dir():
        revision_file = refs_dir / revision
        if revision_file.exists():
            with revision_file.open() as f:
                revision = f.read()

    # Check if file is cached as "no_exist"
    if (no_exist_dir / revision / filename).is_file():
        return _CACHED_NO_EXIST

    # Check if revision folder exists
    if not snapshots_dir.exists():
        return None
    cached_shas = os.listdir(snapshots_dir)
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        return None

    # Check if file exists in cache
    cached_file = snapshots_dir / revision / filename
    return str(cached_file) if cached_file.is_file() else None

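# The lookup above mirrors the huggingface_hub cache layout, roughly:
#
#     <HUGGINGFACE_HUB_CACHE>/models--<org>--<name>/
#         refs/main                 # text file containing a commit sha
#         snapshots/<sha>/<file>    # actual cached files
#         .no_exist/<sha>/<file>    # marker for files known to be absent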

def weight_files(model_id, revision=None, extension=".safetensors"):
    """Get the local safetensors filenames"""
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

    filenames = weight_hub_files(model_id, revision, extension)
    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(
            model_id, revision=revision, filename=filename
        )
        if cache_file is None:
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_id} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_id}` first."
            )
        files.append(cache_file)

    return files

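# Note: when WEIGHTS_CACHE_OVERRIDE is set, both `weight_files` and
# `download_weights` short-circuit to that directory, e.g. (path and command
# are illustrative):
#
#     WEIGHTS_CACHE_OVERRIDE=/data/bloom text-generation-server serve bigscience/bloom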

def download_weights(model_id, revision=None, extension=".safetensors"):
    """Download the safetensors files from the hub"""
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

    filenames = weight_hub_files(model_id, revision, extension)

    download_function = partial(
        hf_hub_download,
        repo_id=model_id,
        local_files_only=False,
    )

    # Download the files in parallel and release the thread pool when done
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(download_function, filename=filename, revision=revision)
            for filename in filenames
        ]
        files = [
            future.result()
            for future in tqdm(
                concurrent.futures.as_completed(futures), total=len(futures)
            )
        ]

    return files
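

# Usage sketch (model id illustrative): typically invoked through the CLI
# mentioned in the error message above, which ends up calling:
#
#     download_weights("bigscience/bloom-560m")
#     # -> list of local paths to the downloaded .safetensors files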