data_utils.py 1.44 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
import torch

def one_hotify(labels, pad=-1):
        '''
        cast label to one hot vector
        '''
        num_instances = len(labels)
        if pad <= 0:
            dim_embedding = np.max(labels) + 1 #zero-indexed assumed
        else:
            assert pad > 0, "result_dim for padding one hot embedding not set!"
            dim_embedding = pad + 1
        embeddings = np.zeros((num_instances, dim_embedding))
        embeddings[np.arange(num_instances), labels] = 1

        return embeddings


def pre_process(dataset, prog_args):
        """
        diffpool specific data partition, pre-process and shuffling
        """
        if prog_args.data_mode != "default":
            print("overwrite node attributes with DiffPool's preprocess setting")
            if prog_args.data_mode == 'id':
                for g, _ in dataset:
                    id_list = np.arange(g.number_of_nodes())
                    g.ndata['feat'] = one_hotify(id_list, pad=dataset.max_num_node)
                    
            elif prog_args.data_mode == 'deg-num':
                for g, _ in dataset:
                    g.ndata['feat'] = np.expand_dims(g.in_degrees(), axis=1)
                
            elif prog_args.data_mode == 'deg':
                for g in dataset:
                    degs = list(g.in_degrees())
                    degs_one_hot = one_hotify(degs, pad=dataset.max_degrees)
                    g.ndata['feat'] = degs_one_hot