utils.py 6.17 KB
Newer Older
1
import concurrent
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
import os
3
import re
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
4
5
import torch
import torch.distributed
Olivier Dehaene's avatar
Olivier Dehaene committed
6
7

from datetime import timedelta
Nicolas Patry's avatar
Nicolas Patry committed
8

9
from concurrent.futures import ThreadPoolExecutor
Nicolas Patry's avatar
Nicolas Patry committed
10
11
12
13
from functools import partial
from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
14
from typing import List, Optional, Tuple
15
from transformers import PreTrainedTokenizerBase
16
from transformers.generation.logits_process import (
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
17
18
19
20
21
22
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    TopKLogitsWarper,
)

23
24
from text_generation.pb import generate_pb2

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

class Sampling:
    """Token chooser that draws one token per row from the softmax distribution."""

    def __call__(self, logits):
        """Sample one token id per row of `logits`.

        The logits are normalised with a softmax over the last dimension,
        one index is drawn per row with `torch.multinomial`, and the
        resulting sample dimension (dim 1) is squeezed away.
        """
        distribution = logits.softmax(dim=-1)
        drawn = torch.multinomial(distribution, num_samples=1)
        return drawn.squeeze(1)


class Greedy:
    """Token chooser that always picks the highest-scoring token."""

    def __call__(self, logits):
        """Return the argmax of `logits` along the last dimension."""
        return torch.argmax(logits, dim=-1)


class NextTokenChooser:
    """Warp next-token scores and choose the next token ids.

    A `LogitsProcessorList` of warpers is assembled from the constructor
    arguments. Tokens are sampled (`Sampling`) when `do_sample` is truthy or
    when at least one warper was added; otherwise the argmax is taken
    (`Greedy`).
    """

    def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False):
        """Build the warper list and select the choice strategy.

        Args:
            temperature: added as a `TemperatureLogitsWarper` (cast to float)
                unless it is None or equal to 1.0.
            top_k: added as a `TopKLogitsWarper` unless it is None or 0.
            top_p: added as a `TopPLogitsWarper` when it is not None and
                strictly below 1.0.
            do_sample: request sampling even if no warper is configured.
        """
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        processors = LogitsProcessorList()
        must_sample = do_sample

        if not (temperature is None or temperature == 1.0):
            processors.append(TemperatureLogitsWarper(float(temperature)))
            must_sample = True

        if not (top_k is None or top_k == 0):
            processors.append(TopKLogitsWarper(top_k=top_k))
            must_sample = True

        if top_p is not None and top_p < 1.0:
            processors.append(TopPLogitsWarper(top_p=top_p))
            must_sample = True

        self.warpers = processors
        # Any configured warper switches the chooser to sampling.
        if must_sample:
            self.choice = Sampling()
        else:
            self.choice = Greedy()

    def __call__(self, input_ids, scores):
        """Return `(next_ids, logprobs)` for the given scores.

        `scores` are first passed through the warpers; the log-softmax and
        the token choice are both computed from the warped scores.
        """
        # Warp logits
        warped = self.warpers(input_ids, scores)
        # Compute logprobs
        log_probabilities = torch.log_softmax(warped, -1)
        # Choose tokens
        chosen_ids = self.choice(warped)
        return chosen_ids, log_probabilities

    @classmethod
    def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
        """Create a chooser from the `temperature`, `top_k`, `top_p` and
        `do_sample` fields of a protobuf parameters message."""
        settings = {
            "temperature": pb.temperature,
            "top_k": pb.top_k,
            "top_p": pb.top_p,
            "do_sample": pb.do_sample,
        }
        return NextTokenChooser(**settings)


class StopSequenceCriteria:
    """Detect whether generated text ends with a given stop sequence."""

    def __init__(self, stop_sequence: str):
        """Compile the matcher for `stop_sequence`.

        Args:
            stop_sequence: literal text that should end generation when it
                appears at the end of the output.
        """
        # The stop sequence is literal text, not a pattern: escape it so that
        # characters such as ".", "(", "|" or "?" neither change what is
        # matched nor make the pattern fail to compile.
        self.regex = re.compile(f".*{re.escape(stop_sequence)}$")

    def __call__(self, output: str) -> bool:
        """Return True when `output` ends with the stop sequence."""
        return self.regex.search(output) is not None

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
86
87

class StoppingCriteria:
    """Track one generation and decide, token by token, whether it must stop.

    Checks are made in this order: token budget, end-of-sequence token, then
    the stop-sequence criteria applied to the accumulated output text.
    """

    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens=20,
    ):
        """Store the stopping configuration and reset the counters.

        Args:
            eos_token_id: token id reported as the "eos_token" stop reason.
            stop_sequence_criterias: callables applied to the accumulated
                output text; a truthy result stops generation.
            max_new_tokens: number of calls after which generation stops
                with the "length" reason.
        """
        # Configuration
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        # Running state
        self.current_tokens = 0
        self.current_output = ""

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        """Register the latest token and report whether to stop.

        Returns:
            `(True, reason)` with reason "length", "eos_token" or
            "stop_sequence", or `(False, None)` to keep generating.
        """
        self.current_tokens = self.current_tokens + 1

        # The length budget is checked before anything else.
        budget_exhausted = self.current_tokens >= self.max_new_tokens
        if budget_exhausted:
            return True, "length"

        if last_token == self.eos_token_id:
            return True, "eos_token"

        # Only tokens that did not already stop generation extend the text
        # inspected by the stop-sequence criteria.
        self.current_output = self.current_output + last_output
        if any(
            criteria(self.current_output) for criteria in self.stop_sequence_criterias
        ):
            return True, "stop_sequence"

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        """Build the criteria from a protobuf message and a tokenizer.

        One `StopSequenceCriteria` is created per entry of
        `pb.stop_sequences`; the EOS id comes from `tokenizer.eos_token_id`
        and the token budget from `pb.max_new_tokens`.
        """
        criterias = list(map(StopSequenceCriteria, pb.stop_sequences))
        return StoppingCriteria(
            tokenizer.eos_token_id,
            criterias,
            pb.max_new_tokens,
        )
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147


def initialize_torch_distributed():
    """Initialise `torch.distributed` from the `RANK` / `WORLD_SIZE` env vars.

    `RANK` defaults to 0 and `WORLD_SIZE` to 1 when unset. With CUDA
    available the "nccl" backend is used and the current device is set to
    `rank % device_count`; otherwise the "gloo" backend is used. The process
    group is initialised with a 60 second timeout.

    Returns:
        A `(process_group, rank, world_size)` tuple, where `process_group`
        is the default group.
    """
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    backend = "gloo"
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        assert world_size <= gpu_count, "Each process is one gpu"
        # Bind this process to its own GPU.
        torch.cuda.set_device(rank % gpu_count)
        backend = "nccl"

    # Call the init process.
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
    )

    process_group = torch.distributed.distributed_c10d._get_default_group()
    return process_group, rank, world_size


154
def weight_hub_files(model_name, extension=".safetensors"):
    """List the hub files of `model_name` whose name ends with `extension`.

    The model info is fetched through `HfApi` and the `rfilename` of every
    matching sibling is returned, in the order the siblings are listed.
    """
    repo_info = HfApi().model_info(model_name)

    matching = []
    for sibling in repo_info.siblings:
        name = sibling.rfilename
        if name.endswith(extension):
            matching.append(name)
    return matching


162
def weight_files(model_name, extension=".safetensors"):
    """Return the local cache paths of the weight files of `model_name`.

    The filenames are taken from the hub listing (`weight_hub_files`) and
    each one is then looked up in the local cache.

    Args:
        model_name: hub repository id of the model.
        extension: filename suffix of the weight files to look for.

    Returns:
        The cached file paths, in hub listing order.

    Raises:
        LocalEntryNotFoundError: if any listed file is not in the local cache.
    """
    filenames = weight_hub_files(model_name, extension)
    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(model_name, filename=filename)
        # Only a path string is a cache hit. Besides `None`, the lookup can
        # return a sentinel meaning "known not to exist", which must not be
        # handed back to the caller as if it were a file path.
        if not isinstance(cache_file, str):
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_name} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_name}` first."
            )
        files.append(cache_file)

    return files


179
def download_weights(model_name, extension=".safetensors"):
    """Download the weight files of `model_name` from the hub.

    The files listed by `weight_hub_files` are fetched with
    `hf_hub_download` on a pool of 5 worker threads, with a `tqdm` progress
    bar over the completed downloads.

    Args:
        model_name: hub repository id of the model.
        extension: filename suffix of the weight files to download.

    Returns:
        The values returned by `hf_hub_download`, in completion order (not
        necessarily the hub listing order).
    """
    filenames = weight_hub_files(model_name, extension)

    download_function = partial(
        hf_hub_download,
        repo_id=model_name,
        local_files_only=False,
    )

    # The context manager shuts the pool down once the downloads are done
    # (or one of them raised), so worker threads are not left behind.
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(download_function, filename=filename)
            for filename in filenames
        ]
        files = [
            future.result()
            for future in tqdm(
                concurrent.futures.as_completed(futures), total=len(futures)
            )
        ]

    return files