utils.py 4.52 KB
Newer Older
1
import concurrent
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
import os
3
import signal
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
4
5
import torch
import torch.distributed
Olivier Dehaene's avatar
Olivier Dehaene committed
6
7

from datetime import timedelta
Nicolas Patry's avatar
Nicolas Patry committed
8

9
from concurrent.futures import ThreadPoolExecutor
Nicolas Patry's avatar
Nicolas Patry committed
10
11
12
13
from functools import partial
from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
14
from transformers.generation.logits_process import (
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    TopKLogitsWarper,
)


class Sampling:
    """Stochastic chooser: draw one token id from ``softmax(logits)``."""

    def __call__(self, logits):
        """Sample a token id from the distribution defined by ``logits``.

        Args:
            logits: tensor whose last dimension indexes the vocabulary; either
                1-D (one distribution) or 2-D (one distribution per row), the
                ranks ``torch.multinomial`` accepts.

        Returns:
            Sampled ids with the vocabulary dimension removed, i.e. the same
            shape ``Greedy`` (``argmax(dim=-1)``) produces for the same input.
        """
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # multinomial appends a trailing sample dimension of size 1; drop it
        # from the end so 1-D input works too (``squeeze(1)`` fails on 1-D).
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
        return next_tokens


class Greedy:
    """Deterministic chooser: take the highest-scoring id along the last axis."""

    def __call__(self, logits):
        best_ids = torch.argmax(logits, dim=-1)
        return best_ids


class NextTokenChooser:
    """Pick the next token id from one decoding step's scores.

    Optional logits warpers (temperature, top-k, top-p) are applied to the
    scores first. If any warper is enabled, or ``do_sample`` is truthy, the id
    is drawn by ``Sampling``; otherwise ``Greedy`` takes the argmax.
    """

    def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False):
        # Approach borrowed from https://github.com/huggingface/transformers/pull/5420/files
        # Warpers are registered in a fixed order: temperature, top-k, top-p.
        processors = LogitsProcessorList()
        warped = False

        if temperature is not None and temperature != 1.0:
            processors.append(TemperatureLogitsWarper(float(temperature)))
            warped = True
        if top_k is not None and top_k != 0:
            processors.append(TopKLogitsWarper(top_k=top_k))
            warped = True
        if top_p is not None and top_p < 1.0:
            processors.append(TopPLogitsWarper(top_p=top_p))
            warped = True

        self.warpers = processors
        # Any active warper implies sampling; otherwise honour ``do_sample``.
        self.choice = Sampling() if (do_sample or warped) else Greedy()

    def __call__(self, input_ids, scores):
        """Warp ``scores``, choose ids, and return them with a trailing axis of size 1."""
        chosen = self.choice(self.warpers(input_ids, scores))
        return chosen.unsqueeze(-1)


class StoppingCriteria:
    """Decide when generation for a sequence should stop.

    The object is stateful: every call counts as one newly generated token.
    Generation stops once ``max_new_tokens`` calls have been made, or as soon
    as ``all_ids[-1]`` equals ``eos_token_id`` (when an EOS id is set).
    """

    def __init__(self, eos_token_id, max_new_tokens=20):
        # Token id that ends generation early; ``None`` disables the EOS check.
        self.eos_token_id = eos_token_id
        # Upper bound on the number of calls (generated tokens).
        self.max_new_tokens = max_new_tokens
        # Number of calls made so far.
        self.current_tokens = 0

    def __call__(self, all_ids):
        """Record one new token and return ``True`` if generation should stop.

        Args:
            all_ids: indexable sequence of token ids; only ``all_ids[-1]`` is
                compared against ``eos_token_id``.

        Returns:
            ``True`` when the token budget is exhausted or the last id is the
            EOS id, otherwise ``False``.
        """
        self.current_tokens += 1
        # The budget check comes first, so it applies even without an EOS id.
        if self.current_tokens >= self.max_new_tokens:
            return True
        if self.eos_token_id is not None and all_ids[-1] == self.eos_token_id:
            return True
        return False


def initialize_torch_distributed():
    """Initialise the default ``torch.distributed`` process group for this process.

    Rank and world size are read from the ``RANK`` and ``WORLD_SIZE``
    environment variables, defaulting to a single-process group (rank 0 of 1).
    When CUDA is available, each rank is pinned to GPU
    ``rank % torch.cuda.device_count()`` and the NCCL backend is used;
    otherwise the Gloo backend is used.

    Returns:
        A ``(process_group, rank, world_size)`` tuple, where ``process_group``
        is the default group created by ``init_process_group``.
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    if torch.cuda.is_available():
        # One process per GPU: there must be at least as many GPUs as ranks.
        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
        device = rank % torch.cuda.device_count()
        torch.cuda.set_device(device)
        backend = "nccl"
    else:
        backend = "gloo"

    # No ``init_method`` is given, so torch uses its ``env://`` default, which
    # reads the rendezvous address from the environment (MASTER_ADDR /
    # MASTER_PORT) -- presumably set by the launcher; confirm.
    # The 60 s timeout is much shorter than torch's default, so a missing peer
    # fails fast.
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
    )

    return torch.distributed.distributed_c10d._get_default_group(), rank, world_size


def weight_hub_files(model_name, extension=".safetensors"):
    """Return the names of a Hub repository's files that end with ``extension``.

    This performs a network call (``HfApi.model_info``) to list the repository.

    Args:
        model_name: Hub repository id passed to ``HfApi.model_info``.
        extension: filename suffix to keep (default ``".safetensors"``).

    Returns:
        Repository-relative filenames that match, in the order the Hub lists them.
    """
    info = HfApi().model_info(model_name)
    # ``siblings`` is typed Optional in huggingface_hub; treat a missing
    # listing as "no files" instead of raising ``TypeError``.
    siblings = info.siblings or []
    return [s.rfilename for s in siblings if s.rfilename.endswith(extension)]


def weight_files(model_name, extension=".safetensors"):
    """Return local cache paths for a Hub model's weight files.

    The expected filenames come from the Hub (``weight_hub_files``), so this
    still needs network access. The files themselves must already be present
    in the local Hugging Face cache.

    Args:
        model_name: Hub repository id.
        extension: filename suffix to look for (default ``".safetensors"``).

    Returns:
        Local paths of the cached files, in the Hub's listing order.

    Raises:
        LocalEntryNotFoundError: if any expected file is not in the local cache.
    """
    filenames = weight_hub_files(model_name, extension)

    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(model_name, filename=filename)
        # Besides ``None`` (not cached), newer huggingface_hub versions can
        # return a sentinel meaning "known not to exist"; only a str is a path.
        if not isinstance(cache_file, str):
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_name} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_name}` first."
            )
        files.append(cache_file)

    return files


def download_weights(model_name, extension=".safetensors"):
    """Download a Hub model's weight files into the local cache.

    Files are fetched concurrently (at most 5 at a time), and a tqdm progress
    bar advances as each download finishes.

    Args:
        model_name: Hub repository id.
        extension: filename suffix to download (default ``".safetensors"``).

    Returns:
        Local file paths returned by ``hf_hub_download``, in the Hub's listing
        order.

    Raises:
        Any exception raised by ``hf_hub_download`` for a failed download.
    """
    filenames = weight_hub_files(model_name, extension)

    download_function = partial(
        hf_hub_download,
        repo_id=model_name,
        local_files_only=False,
    )

    # The ``with`` block shuts the pool down (waiting for its workers) on
    # exit, including when a download fails.
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(download_function, filename=filename)
            for filename in filenames
        ]
        # ``as_completed`` yields Future objects, not results; calling
        # ``result()`` re-raises the first download failure.
        for future in tqdm(
            concurrent.futures.as_completed(futures), total=len(futures)
        ):
            future.result()

    # Every future has completed successfully here; return the paths in
    # submission order.
    return [future.result() for future in futures]