import concurrent.futures
import os
import re
import torch
import torch.distributed

from datetime import timedelta

from concurrent.futures import ThreadPoolExecutor
from functools import partial
from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
from transformers.generation.logits_process import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    TopKLogitsWarper,
)

from text_generation.pb import generate_pb2


class Sampling:
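    """Sample the next token from the softmax of the logits.

    A dedicated `torch.Generator` makes sampling reproducible when a seed
    is provided.
    """
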
    def __init__(self, seed: Optional[int] = None, device: str = "cpu"):
        self.generator = torch.Generator(device)
        if seed is not None:
            self.generator.manual_seed(seed)
        else:
            self.generator.seed()

    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, dim=-1)
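        # Draw one token id per batch entry from the categorical distribution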
        next_tokens = torch.multinomial(
            probs, num_samples=1, generator=self.generator
        ).squeeze(1)
        return next_tokens

    @property
    def seed(self) -> int:
        return self.generator.initial_seed()


class Greedy:
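    """Greedy decoding: always pick the highest-logit token."""
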
    def __call__(self, logits):
        return logits.argmax(dim=-1)


class NextTokenChooser:
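    """Warp logits (temperature / top-k / top-p) and choose the next token.

    Falls back to greedy decoding when no sampling parameter is set.
    Illustrative usage (assumes `scores` is a `[batch, vocab_size]` float
    tensor from a causal LM forward pass):

        chooser = NextTokenChooser(temperature=0.7, top_k=50, do_sample=True)
        next_ids, logprobs = chooser(input_ids, scores)
    """
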
    def __init__(
        self,
        temperature=1.0,
        top_k=None,
        top_p=None,
        do_sample=False,
        seed=None,
        device="cpu",
    ):
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all warpers can be found in `transformers.generation.logits_process`
        sampling = do_sample
        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            warpers.append(TemperatureLogitsWarper(temperature))
            sampling = True
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
            sampling = True
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        # Warp logits
        scores = self.warpers(input_ids, scores)
        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)
        # Choose tokens
        next_ids = self.choice(scores)
        return next_ids, logprobs

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
    ) -> "NextTokenChooser":
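        """Build a `NextTokenChooser` from its protobuf representation."""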
        # handle protobuf making default values 0
        seed = pb.seed if pb.HasField("seed") else None
        return NextTokenChooser(
            temperature=pb.temperature,
            top_k=pb.top_k,
            top_p=pb.top_p,
            do_sample=pb.do_sample,
            seed=seed,
            device=str(device),
        )


class StopSequenceCriteria:
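    """Check whether the generated text ends with a given stop sequence."""
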
    def __init__(self, stop_sequence: str):
        # Escape the stop sequence so regex metacharacters match literally
        self.regex = re.compile(f".*{re.escape(stop_sequence)}$")

    def __call__(self, output: str) -> bool:
        if self.regex.findall(output):
            return True
        return False


class StoppingCriteria:
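    """Decide when generation should stop, and with which finish reason.

    Stops on `max_new_tokens` ("length"), the EOS token ("eos_token"),
    or any configured stop sequence ("stop_sequence").
    """
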
    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens=20,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_output = ""

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
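        """Return `(should_stop, finish_reason)` for the latest token and text."""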
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"

        if last_token == self.eos_token_id:
            return True, "eos_token"

        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
                return True, "stop_sequence"

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
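        """Build a `StoppingCriteria` from the protobuf parameters and tokenizer."""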
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        return StoppingCriteria(
            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
        )


def initialize_torch_distributed():
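    """Initialize the default process group from the RANK / WORLD_SIZE env vars.

    Uses the NCCL backend when CUDA is available (one GPU per process) and
    falls back to Gloo on CPU.
    """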
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    if torch.cuda.is_available():
        # Set the device for this process; NCCL expects one GPU per process
        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
        device = rank % torch.cuda.device_count()
        torch.cuda.set_device(device)
        backend = "nccl"
    else:
        backend = "gloo"

    # Call the init process.
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
    )

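    # NOTE: `_get_default_group` is a private torch API; it returns the group
    # created by `init_process_group` above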
    return torch.distributed.distributed_c10d._get_default_group(), rank, world_size


def weight_hub_files(model_name, extension=".safetensors"):
    """Get the safetensors filenames on the hub"""
    api = HfApi()
    info = api.model_info(model_name)
    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
    return filenames


def weight_files(model_name, extension=".safetensors"):
    """Get the local safetensors filenames"""
    filenames = weight_hub_files(model_name, extension)
    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(model_name, filename=filename)
        if cache_file is None:
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_name} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_name}` first."
            )
        files.append(cache_file)

    return files


def download_weights(model_name, extension=".safetensors"):
    """Download the safetensors files from the hub"""
    filenames = weight_hub_files(model_name, extension)

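    # Bind the repo once so each worker call only needs the filename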
    download_function = partial(
        hf_hub_download,
        repo_id=model_name,
        local_files_only=False,
    )

    # Use a context manager so worker threads are cleaned up when done
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(download_function, filename=filename)
            for filename in filenames
        ]
        files = [
            future.result()
            for future in tqdm(
                concurrent.futures.as_completed(futures), total=len(futures)
            )
        ]

    return files