misc.py 2.4 KB
Newer Older
1
2
3
4
5
6
7
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file reuses the implementation from https://github.com/yl-1993/learn-to-cluster
"""

import json
8
import os
9
10
import pickle
import random
11
12
import time

13
14
import numpy as np

15

16
class TextColors:
    """ANSI escape sequences for colouring and styling terminal output.

    Prefix text with one of the style constants and append ``ENDC`` so the
    terminal returns to its default rendering afterwards.
    """

    HEADER = "\033[35m"  # SGR 35: magenta foreground
    OKBLUE = "\033[34m"  # SGR 34: blue foreground
    OKGREEN = "\033[32m"  # SGR 32: green foreground
    WARNING = "\033[33m"  # SGR 33: yellow foreground
    FATAL = "\033[31m"  # SGR 31: red foreground
    ENDC = "\033[0m"  # SGR 0: reset all attributes
    BOLD = "\033[1m"  # SGR 1: bold
    UNDERLINE = "\033[4m"  # SGR 4: underline


class Timer:
    """Context manager that prints the wall-clock time spent in its body.

    Args:
        name: Label included in the printed message.
        verbose: When false, nothing is printed on exit.
    """

    def __init__(self, name="task", verbose=True):
        self.name = name
        self.verbose = verbose

    def __enter__(self):
        # Record the start time; __exit__ reads it back to compute elapsed.
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.verbose:
            elapsed = time.time() - self.start
            print("[Time] {} consumes {:.4f} s".format(self.name, elapsed))
        # Evaluates to False whenever an exception is in flight, so errors
        # raised inside the ``with`` body are never suppressed.
        return exc_type is None

45

46
47
def set_random_seed(seed, cuda=False):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``.
        cuda: When true, also call ``torch.cuda.manual_seed_all(seed)``.
    """
    # Local import: torch is only needed when this function is called.
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)

55

56
57
58
59
def l2norm(vec):
    """Scale ``vec`` in place by its per-row Euclidean norms and return it.

    Norms are taken along axis 1 and reshaped to a column so that they
    broadcast across each row. The division is in place: the array passed
    in is modified and the same object is returned. Zero norms are not
    special-cased.

    NOTE(review): assumes ``vec`` is a 2-D floating-point array — confirm
    against callers.
    """
    row_norms = np.linalg.norm(vec, axis=1)
    vec /= row_norms.reshape(-1, 1)
    return vec

60

61
62
63
64
65
def is_l2norm(features, size):
    """Spot-check whether one randomly chosen row of ``features`` has unit norm.

    A single row index is drawn from ``range(size)``; the dot product of
    that row with itself (its squared norm) is compared to 1 with an
    absolute tolerance of 1e-6. Only that one row is inspected.

    Args:
        features: Array indexed as ``features[i, :]``.
        size: Upper bound (exclusive) for the sampled row index.

    Returns:
        Whether the sampled row's squared norm is within 1e-6 of 1.
    """
    idx = random.choice(range(size))
    row = features[idx, :]
    squared_norm = np.dot(row, row)
    return abs(squared_norm - 1) < 1e-6

66

67
68
69
def is_spmat_eq(a, b):
    """Return whether ``a != b`` has no stored non-zero entries.

    NOTE(review): presumably both operands are SciPy sparse matrices of the
    same shape, since the result of ``a != b`` must expose ``.nnz`` —
    confirm against callers.
    """
    mismatch = a != b
    return mismatch.nnz == 0

70

71
72
73
74
75
76
def aggregate(features, adj, times):
    """Left-multiply ``features`` by ``adj`` ``times`` times.

    Each step evaluates ``adj * features``, so the meaning of ``*`` is
    whatever the type of ``adj`` defines. The result is cast back to the
    dtype ``features`` had on entry.

    NOTE(review): presumably ``adj`` is a SciPy sparse adjacency matrix, for
    which ``*`` is a matrix product — confirm against callers.

    Args:
        features: Array-like with ``dtype`` and ``astype``.
        adj: Left operand of each multiplication.
        times: Number of multiplications; 0 only performs the dtype cast.
    """
    original_dtype = features.dtype
    propagated = features
    for _ in range(times):
        propagated = adj * propagated
    return propagated.astype(original_dtype)

77
78
79

def mkdir_if_no_exists(path, subdirs=("",), is_folder=False):
    """Create the directory that should contain ``path`` if it is missing.

    For every entry ``sd`` of ``subdirs`` the directory to create is

    * ``dirname(join(path, sd))`` when ``sd`` is non-empty or ``is_folder``
      is true (so ``is_folder=True`` with the default ``subdirs`` creates
      ``path`` itself), or
    * ``dirname(path)`` otherwise, i.e. ``path`` is treated as a file path.

    Args:
        path: File or folder path. An empty string is a no-op.
        subdirs: Names joined onto ``path``. Only the ``dirname`` of the
            joined path is created, so an entry must end with a path
            separator for that entry itself to become a directory.
        is_folder: Treat ``path`` as a directory rather than a file.
    """
    if path == "":
        return
    for sd in subdirs:
        if sd != "" or is_folder:
            target = os.path.dirname(os.path.join(path, sd))
        else:
            target = os.path.dirname(path)
        # A bare file name has an empty dirname, meaning the current
        # directory; os.makedirs("") would raise, so there is nothing to do.
        if target and not os.path.exists(target):
            # exist_ok guards against another process creating the directory
            # between the existence check and this call.
            os.makedirs(target, exist_ok=True)

89
90
91
92
93
94
95
96
97

def stop_iterating(
    current_l,
    total_l,
    early_stop,
    num_edges_add_this_level,
    num_edges_add_last_level,
    knn_k,
):
    """Decide whether the level-by-level iteration should stop.

    Args:
        current_l: Zero-based index of the current level.
        total_l: Total number of levels.
        early_stop: Enables stopping rule 3.
        num_edges_add_this_level: Edges added at the current level.
        num_edges_add_last_level: Edges added at the previous level.
        knn_k: The ``k`` whose value minus one is the ratio threshold of
            rule 3.

    Returns:
        True if any stopping rule fires, otherwise False.
    """
    # Stopping rule 1: run all levels
    if current_l == total_l - 1:
        return True
    # Stopping rule 2: no new edges
    if num_edges_add_this_level == 0:
        return True
    # Stopping rule 3: early stopping, two levels start to produce similar
    # numbers of edges. Rule 2 has already returned for a zero count, so the
    # division below cannot divide by zero.
    if early_stop:
        ratio = float(num_edges_add_last_level) / num_edges_add_this_level
        if ratio < knn_k - 1:
            return True
    return False