# label_utils.py
from collections import defaultdict

3
4
5
6
import numpy as np
import torch


7
def remove_unseen_classes_from_training(train_mask, labels, removed_class):
    """Drop the nodes of the unseen classes from the training mask.

    This yields the zero-shot (i.e., completely imbalanced) label setting:
    every training node whose label is in ``removed_class`` is unselected.

    Input:
        train_mask: 1-D mask tensor; an entry equal to 1 (or True) marks a training node
        labels: tensor holding one class id per node, aligned with ``train_mask``
        removed_class: iterable of the class ids to treat as unseen
    Output:
        train_mask_zs: a copy of ``train_mask`` that only selects seen-class nodes;
        ``train_mask`` itself is not modified
    """
    train_mask_zs = train_mask.clone()

    # One label per mask entry, on the mask's device so both can be combined.
    node_labels = labels.reshape(-1)[: train_mask_zs.numel()].to(train_mask_zs.device)

    # Mark every node whose label is one of the removed classes. Looping over
    # the (few) classes instead of the (many) nodes avoids one .item() call
    # per node.
    unseen = torch.zeros_like(node_labels, dtype=torch.bool)
    for cls in removed_class:
        unseen |= node_labels == cls

    # Only entries that are currently selected (== 1) are cleared; all other
    # entries keep their value.
    train_mask_zs[(train_mask_zs == 1) & unseen] = 0
    return train_mask_zs
17
18


19
def get_class_set(labels):
    """Get the class set.

    Input: labels [l, [c1, c2, ..]] -- an iterable with one iterable of labels
        per node; every label must be convertible with ``int()``
    Output: the labeled class set dict_keys([k1, k2, ..]) -- each distinct class
        id (as int) exactly once, in order of first appearance
    """
    # A dict (rather than a set) is used so the result stays a dict_keys view
    # that preserves first-seen order.
    return {int(label): 1 for y in labels for label in y}.keys()

30

31
def get_label_attributes(train_mask_zs, nodeids, labellist, features):
    """Get the class-center (semantic knowledge) of each seen class.

    Suppose a node i is labeled as c, then node i's feature row contributes to
    class c; the class center is the mean of all contributing rows.

    Input:
        train_mask_zs: unused here; kept so existing callers keep working
        nodeids: iterable of node ids (convertible with ``int()``), used as row
            indices into ``features``
        labellist: iterable aligned with ``nodeids``; each entry is an iterable
            of the labels (convertible with ``int()``) of that node
        features: 2-D numpy array with one feature row per node
    Output:
        label_attribute{}: label (int) -> average_labeled_node_features (class center)
    """
    # Group the row indices of the labeled nodes by class.
    nodes_by_label = defaultdict(list)
    for nodeid, node_labels in zip(nodeids, labellist):
        for label in node_labels:
            nodes_by_label[int(label)].append(int(nodeid))

    # The class center is the column-wise mean over the rows of that class.
    return {
        label: np.mean(features[nodes, :], axis=0)
        for label, nodes in nodes_by_label.items()
    }

50

51
def get_labeled_nodes_label_attribute(train_mask_zs, labels, features, cuda):
    """Replace the original labels by their class-centers.

    For each label c in the training set, the following operations are performed:
    get label_attribute{} through function get_label_attributes, then
    res[i, :] = label_attribute[c].

    Input:
        train_mask_zs: mask (or index) selecting the training nodes along dim 0
        labels: tensor with one class id per node
        features: 2-D tensor with one feature row per node
        cuda: if truthy, the result is moved to the GPU
    Output:
        Y_{semantic} [l, ft]: float tensor with one class-center row per
        selected training node
    """
    # Index with a CPU copy of the mask: the node-id range below lives on the
    # CPU, and a CPU index is also accepted by a tensor on another device.
    mask = train_mask_zs.cpu() if torch.is_tensor(train_mask_zs) else train_mask_zs
    nodeids = torch.arange(features.shape[0])[mask].tolist()
    # One single-element label list per node, the format get_label_attributes expects.
    labellist = [[int(label)] for label in labels[mask].cpu().tolist()]

    # 1. get the semantic knowledge (class centers) of all seen classes
    label_attribute = get_label_attributes(
        train_mask_zs=train_mask_zs,
        nodeids=nodeids,
        labellist=labellist,
        features=features.detach().cpu().numpy(),
    )

    # 2. replace original labels by their class centers (semantic knowledge)
    res = np.zeros([len(nodeids), features.shape[1]])
    for row, node_labels in enumerate(labellist):
        # support multiple labels: average the centers of all labels of the node
        centers = np.zeros([len(node_labels), features.shape[1]])
        for k, label in enumerate(node_labels):
            centers[k, :] = label_attribute[int(label)]
        res[row, :] = np.mean(centers, axis=0)

    res = torch.FloatTensor(res)
    if cuda:
        res = res.cuda()
    return res