"vscode:/vscode.git/clone" did not exist on "2e7ab862e3d62b68d553bc74ea1b86e7ddd93401"
utils.py 7.71 KB
Newer Older
1
import datetime
limm's avatar
limm committed
2
3
import errno
import os
4
import time
limm's avatar
limm committed
5
6
from collections import defaultdict, deque

7
8
9
10
import torch
import torch.distributed as dist


limm's avatar
limm committed
11
class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        # `deque` keeps only the last `window_size` values for windowed stats;
        # `total`/`count` accumulate over the full series for the global average.
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        # The window stores the value once; the running totals weight it by n.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        combined = reduce_across_processes([self.count, self.total]).tolist()
        self.count = int(combined[0])
        self.total = combined[1]

    @property
    def median(self):
        # Median over the current window only.
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        # Mean over the current window only (float32 to match original behavior).
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        # Average over every value ever seen, not just the window.
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        # Most recent value.
        return self.deque[-1]

    def __str__(self):
        stats = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**stats)
64
65


limm's avatar
limm committed
66
class MetricLogger:
    """Collect named SmoothedValue meters and periodically log them.

    Meters are created on demand: updating an unknown name lazily creates a
    default SmoothedValue for it.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Update each named meter with a scalar (float, int, or 0-dim tensor)."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if not isinstance(v, (float, int)):
                raise TypeError(
                    f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
                )
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Allow `logger.loss` as shorthand for `logger.meters["loss"]`.
        # Only called when normal attribute lookup fails.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every meter's count/total across distributed processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register a pre-configured meter under `name` (e.g. custom fmt)."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from `iterable`, printing progress every `print_freq` steps.

        Requires a sized iterable: ``len(iterable)`` is used for the ETA
        estimate and the progress counter. Also prints total elapsed time
        after the loop finishes.
        """
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            # Fix: also log on the last iteration so the final batch's stats
            # are always reported, even when it doesn't align with print_freq.
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str}")
157
158
159
160


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.inference_mode():
        max_k = max(topk)
        num_samples = target.size(0)

        # Indices of the max_k highest-scoring classes, transposed to
        # shape (max_k, batch) so row r holds every sample's r-th guess.
        top_pred = output.topk(max_k, 1, True, True)[1].t()
        hits = top_pred.eq(target[None])

        # For each requested k, count samples whose target appears in the
        # top-k guesses and scale to a percentage.
        scale = 100.0 / num_samples
        return [hits[:k].flatten().sum(dtype=torch.float32) * scale for k in topk]


def mkdir(path):
    """Create directory `path` (with parents), succeeding if it already exists.

    Uses the idiomatic ``exist_ok=True`` instead of the manual
    errno.EEXIST check. Unlike the old version, this raises
    FileExistsError when `path` exists but is not a directory, rather
    than silently ignoring it.
    """
    os.makedirs(path, exist_ok=True)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins

    original_print = builtins.print

    def gated_print(*args, **kwargs):
        # `force=True` lets any process print; pop it so the real
        # print never sees the extra keyword.
        allow = kwargs.pop("force", False) or is_master
        if allow:
            original_print(*args, **kwargs)

    # Replace the builtin globally so all modules are silenced at once.
    builtins.print = gated_print


def is_dist_avail_and_initialized():
    """Return True iff torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Number of participating processes; 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """Rank of this process; 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


def is_main_process():
    """True on rank 0 (and always true outside distributed mode)."""
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    """torch.save that runs only on the main process (no-op elsewhere)."""
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Initialize torch.distributed from the environment and configure `args`.

    Sets args.rank / args.world_size / args.gpu / args.distributed /
    args.dist_backend in place, pins the CUDA device, creates the NCCL
    process group from args.dist_url, and silences printing on non-master
    ranks. If no distributed environment is detected, sets
    args.distributed = False and returns without touching CUDA.
    """
    # torchrun / torch.distributed.launch export RANK, WORLD_SIZE, LOCAL_RANK.
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    # SLURM exports the global proc id; map it to a local GPU index.
    # NOTE(review): this branch never sets args.world_size — presumably the
    # caller provides it; verify before relying on SLURM launches.
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    # Caller pre-populated args.rank (and, assumed, gpu/world_size) manually.
    elif hasattr(args, "rank"):
        pass
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    # NCCL backend: requires CUDA tensors/devices for collectives.
    args.dist_backend = "nccl"
    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
    )
    # Synchronize all ranks before continuing, then mute print on non-masters.
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
limm's avatar
limm committed
254
255
256
257
258
259
260
261
262
263
264


def reduce_across_processes(val, op=dist.ReduceOp.SUM):
    """Reduce `val` with `op` (default SUM) across all processes.

    Always returns a tensor so callers see a consistent type whether or
    not distributed mode is active.
    """
    if not is_dist_avail_and_initialized():
        # Nothing to sync; wrap in a tensor for consistency with the distributed case.
        return torch.tensor(val)

    # The backend is NCCL (see init_distributed_mode), so the reduction
    # must run on a CUDA tensor.
    tensor = torch.tensor(val, device="cuda")
    dist.barrier()
    dist.all_reduce(tensor, op=op)
    return tensor