nodepred-ns.jinja-py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.data import AsNodePredDataset

{{ data_import_code }}

{{ model_code }}

{% if user_cfg.early_stop %}
class EarlyStopping:
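    '''Stop training when the tracked validation score has not improved for ``patience`` consecutive checks.'''
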
    def __init__(self,
                 patience: int = -1,
                 checkpoint_path: str = 'checkpoint.pt'):
        self.patience = patience
        self.checkpoint_path = checkpoint_path
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def step(self, acc, model):
        score = acc
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif score < self.best_score:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, model):
        '''Save the model when the tracked validation score improves.'''
        torch.save(model.state_dict(), self.checkpoint_path)

    def load_checkpoint(self, model):
        model.load_state_dict(torch.load(self.checkpoint_path))
{% endif %}


def load_subtensor(nfeat, labels, seeds, input_nodes, device):
    """
    Extracts features and labels for a subset of nodes
    """
    batch_inputs = nfeat[input_nodes].to(device)
    batch_labels = labels[seeds].to(device)
    return batch_inputs, batch_labels

def evaluate(model, g, nfeat, labels, val_nid, eval_device):
    """
    Evaluate the model on the node set specified by ``val_nid``.
    g : The entire graph.
    nfeat : The features of all the nodes.
    labels : The labels of all the nodes.
    val_nid : The node IDs to evaluate on.
    eval_device : The device to run evaluation on.
    """
    model.eval()
    eval_model = model.to(eval_device)
    g = g.to(eval_device)
    nfeat = nfeat.to(eval_device)
    with torch.no_grad():
        y = eval_model(g, nfeat)
    model.train()
    return accuracy(y[val_nid], labels[val_nid].to(y.device))

def accuracy(logits, labels):
    _, indices = torch.max(logits, dim=1)
    correct = torch.sum(indices == labels)
    return correct.item() * 1.0 / len(labels)

def train(cfg, pipeline_cfg, device, data, model, optimizer, loss_fcn):
    g = data[0]  # Only train on the first graph
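    # Normalize self-loops: remove any existing ones, then add exactly one per node.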
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)
    train_g = val_g = test_g = g

    train_nfeat = val_nfeat = test_nfeat = train_g.ndata['feat']
    train_labels = val_labels = test_labels = train_g.ndata['label']

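    # Convert the boolean masks into node ID tensors; test nodes are taken to be
    # everything outside the train and validation masks.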
    train_nid = torch.nonzero(train_g.ndata['train_mask'], as_tuple=True)[0]
    val_nid = torch.nonzero(val_g.ndata['val_mask'], as_tuple=True)[0]
    test_nid = torch.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0]

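    # Neighbor sampling: each layer samples at most ``fan_out[i]`` neighbors per node,
    # and the dataloader yields (input_nodes, seeds, blocks) mini-batches over train_nid.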
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [int(fanout) for fanout in pipeline_cfg["sampler"]["fan_out"]])
    dataloader = dgl.dataloading.NodeDataLoader(
        train_g,
        train_nid,
        sampler,
        device=device,
        batch_size=pipeline_cfg["sampler"]["batch_size"],
        shuffle=True,
        drop_last=False,
        num_workers=pipeline_cfg["sampler"]["num_workers"])

    {% if user_cfg.early_stop %}
    stopper = EarlyStopping(pipeline_cfg['patience'], pipeline_cfg['checkpoint_path'])    
    {% endif %}
    val_acc = 0.
    for epoch in range(pipeline_cfg['num_epochs']):
        model.train()
        model = model.to(device)
        for step, (input_nodes, seeds, subgs) in enumerate(dataloader):
            # Load the input features as well as output labels
            batch_inputs, batch_labels = load_subtensor(train_nfeat, train_labels,
                                                        seeds, input_nodes, device)
            subgs = [subg.int().to(device) for subg in subgs]
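            # forward_block is expected to run layer-wise message passing over the
            # sampled blocks; the method itself is provided by the generated model code.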
            batch_pred = model.forward_block(subgs, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_acc = accuracy(batch_pred, batch_labels)
            print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | TrainAcc {:.4f}".format(
                epoch, step, loss.item(), train_acc))

        if epoch % pipeline_cfg["eval_period"] == 0 and epoch != 0:
            val_acc = evaluate(model, val_g, val_nfeat, val_labels, val_nid, cfg["eval_device"])
            print('Eval Acc {:.4f}'.format(val_acc))
        {% if user_cfg.early_stop %}
        if stopper.step(val_acc, model):            
            break
        {% endif %}

    {% if user_cfg.early_stop %}
    stopper.load_checkpoint(model)
    {% endif %}
    model.eval()
    with torch.no_grad():
        test_acc = evaluate(model, test_g, test_nfeat, test_labels, test_nid, cfg["eval_device"])
    return test_acc

def main():
    {{ user_cfg_str }}
    device = cfg['device']
    pipeline_cfg = cfg["general_pipeline"]
    # load data
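    # AsNodePredDataset wraps the dataset so it exposes node masks and ``num_classes``
    # in a uniform way for node prediction.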
    data = AsNodePredDataset({{data_initialize_code}})
    # create model
    model_cfg = cfg["model"]
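    # Derive model input/output sizes from the data: if ``embed_size`` > 0 the model
    # learns its own node embeddings, otherwise the raw feature dimension is used.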
    cfg["model"]["data_info"] = {
        "in_size": model_cfg['embed_size'] if model_cfg['embed_size'] > 0 else data[0].ndata['feat'].shape[1],
        "out_size": data.num_classes,
        "num_nodes": data[0].num_nodes()
    }
    model = {{ model_class_name }}(**cfg["model"])
    model = model.to(device)
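    # Loss and optimizer classes are filled in from the user config by the template.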
    loss = torch.nn.{{ user_cfg.general_pipeline.loss }}()
    optimizer = torch.optim.{{ user_cfg.general_pipeline.optimizer.name }}(model.parameters(), **pipeline_cfg["optimizer"])
    # train
    test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss)
    return test_acc

if __name__ == '__main__':
    all_acc = []
    num_runs = {{ user_cfg.general_pipeline.num_runs }}
    for run in range(num_runs):
        print(f'Run experiment #{run}')
        test_acc = main()
        print("Test Accuracy {:.4f}".format(test_acc))
        all_acc.append(test_acc)
    avg_acc = np.round(np.mean(all_acc), 6)
    std_acc = np.round(np.std(all_acc), 6)
    print(f'Accuracy across {num_runs} runs: {avg_acc} ± {std_acc}')