label_utils.py 3.32 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import numpy as np
from collections import defaultdict

def remove_unseen_classes_from_training(train_mask, labels, removed_class):
    ''' Build the zero-shot (completely imbalanced) training mask by dropping unseen classes.
        Input:
            train_mask: 1-D mask tensor; an entry equal to 1 (or True) marks a training node
            labels: tensor indexed like train_mask; labels[i] must hold a single value
            removed_class: container of class ids to exclude (only used with the `in` operator)
        Output: train_mask_zs: a clone of train_mask in which every selected node whose
                label is in removed_class has been reset to 0. train_mask itself is not modified.
    '''
    train_mask_zs = train_mask.clone()
    # Walk a plain-Python snapshot of the mask; only the entry at the current
    # position is ever written, so the snapshot stays valid while iterating.
    for idx, flag in enumerate(train_mask_zs.tolist()):
        if flag == 1 and labels[idx].item() in removed_class:
            train_mask_zs[idx] = 0
    return train_mask_zs
    
def get_class_set(labels):
    ''' Collect the distinct classes that appear in a nested label list.
        Input: labels [l, [c1, c2, ..]]: an iterable of per-node label collections;
               every label must be convertible with int()
        Output: dict_keys([k1, k2, ..]): the distinct integer classes, in first-seen order
    '''
    # A dict (rather than a set) keeps first-seen order and lets us return a keys view.
    flattened = (int(label) for node_labels in labels for label in node_labels)
    return dict.fromkeys(flattened, 1).keys()

def get_label_attributes(train_mask_zs, nodeids, labellist, features):
    ''' Get the class-center (semantic knowledge) of each seen class.
        Suppose a node i is labeled as c, then node i's feature row contributes to class c;
        the class center is the mean of all rows contributed to c.
        Input:
            train_mask_zs: unused here; kept so existing callers keep working
            nodeids: iterable of node ids (convertible with int()), used as row indices into features
            labellist: iterable parallel to nodeids; each item is the collection of labels
                       (convertible with int()) of the corresponding node
            features: feature matrix supporting features[rows, :] indexing
                      (presumably a 2-D NumPy array [n, ft] -- confirm against the caller)
        Output: label_attribute{}: int label -> average_labeled_node_features (class centers)
    '''
    # Group node indices by label. nodeids and labellist are each traversed
    # exactly once, so one-shot iterators are handled correctly.
    nodes_by_label = defaultdict(list)
    for nodeid, node_labels in zip(nodeids, labellist):
        for label in node_labels:
            nodes_by_label[int(label)].append(int(nodeid))

    # Average the feature rows of every class to obtain its center.
    return {
        label: np.mean(features[nodes, :], axis=0)
        for label, nodes in nodes_by_label.items()
    }

def get_labeled_nodes_label_attribute(train_mask_zs, labels, features, cuda):
    ''' Replace the original labels of the selected nodes by their class-centers.
        The class centers are obtained from get_label_attributes; row i of the result
        is the center of the class of the i-th node selected by train_mask_zs.
        Input:
            train_mask_zs: mask used to index both the node-id range and labels
            labels: label tensor, indexed by train_mask_zs
            features: feature tensor; shape[0] is the node count, shape[1] the feature size
            cuda: when truthy, the result is moved to the GPU with .cuda()
        Output: Y_{semantic} [l, ft]: FloatTensor
    '''
    all_ids = torch.LongTensor(range(features.shape[0]))
    # Ids and labels are passed on as strings; each node carries a one-element label list.
    nodeids = [str(idx) for idx in all_ids[train_mask_zs].numpy().tolist()]
    labellist = [[str(lab)] for lab in labels[train_mask_zs].cpu().numpy().tolist()]

    # 1. get the semantic knowledge (class centers) of all seen classes
    label_attribute = get_label_attributes(
        train_mask_zs=train_mask_zs,
        nodeids=nodeids,
        labellist=labellist,
        features=features.cpu().numpy(),
    )

    # 2. replace original labels by their class centers (semantic knowledge)
    feat_dim = features.shape[1]
    res = np.zeros([len(nodeids), feat_dim])
    for row, node_labels in enumerate(labellist):
        # support multiple labels: average the centers of all labels of this node
        centers = np.zeros([len(node_labels), feat_dim])
        for k, label in enumerate(node_labels):
            centers[k, :] = label_attribute[int(label)]
        res[row, :] = np.mean(centers, axis=0)

    out = torch.FloatTensor(res)
    return out.cuda() if cuda else out