import time
from model import Node2vecModel
from utils import load_graph, parse_arguments

from dgl.sampling import node2vec_random_walk


def time_randomwalk(graph, args):
    """
    Measure the mean cost of node2vec random walks started from every node.

    Runs ``node2vec_random_walk`` from all nodes of ``graph`` ``args.runs``
    times using a fixed setting (p=0.25, q=4, walk_length=50), independent of
    ``args.p`` / ``args.q`` / ``args.walk_length``. It then prints the mean
    wall-clock time per run.

    Parameters
    ----------
    graph :
        Graph passed to ``dgl.sampling.node2vec_random_walk``; every node in
        ``graph.nodes()`` is used as a walk start.
    args :
        Parsed command-line arguments (presumably from ``parse_arguments``;
        confirm). Reads ``args.runs`` (number of timed trials) and
        ``args.dataset`` (used only in the printed message).

    Raises
    ------
    ValueError
        If ``args.runs`` is not positive, since the mean would be undefined.
    """
    if args.runs <= 0:
        raise ValueError(
            "args.runs must be a positive integer, got {}".format(args.runs)
        )

    # default setting for testing (fixed so timings are comparable)
    params = {"p": 0.25, "q": 4, "walk_length": 50}
    # Start nodes are identical for every trial; fetch them once so only the
    # walks themselves are timed.
    start_nodes = graph.nodes()

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump with system clock adjustments.
    start_time = time.perf_counter()
    for _ in range(args.runs):
        node2vec_random_walk(graph, start_nodes, **params)
    end_time = time.perf_counter()

    cost_time_avg = (end_time - start_time) / args.runs
    print(
        "Run dataset {} {} trials, mean run time: {:.3f}s".format(
            args.dataset, args.runs, cost_time_avg
        )
    )


def train_node2vec(graph, eval_set, args, learning_rate=0.01):
    """
    Train a node2vec model on ``graph``.

    Builds a ``Node2vecModel`` from the hyper-parameters in ``args`` and calls
    its ``train`` method. ``eval_set`` and ``eval_steps=1`` are forwarded to
    the model unchanged.

    Parameters
    ----------
    graph :
        Graph returned by ``load_graph``; passed to ``Node2vecModel``.
    eval_set :
        Second value returned by ``load_graph``; forwarded to
        ``Node2vecModel`` as ``eval_set``.
    args :
        Parsed command-line arguments. Reads ``embedding_dim``,
        ``walk_length``, ``p``, ``q``, ``num_walks``, ``device``, ``epochs``
        and ``batch_size``.
    learning_rate : float, optional
        Value passed as ``learning_rate`` to ``trainer.train``. Defaults to
        0.01, the value previously hard-coded here.
    """
    trainer = Node2vecModel(
        graph,
        embedding_dim=args.embedding_dim,
        walk_length=args.walk_length,
        p=args.p,
        q=args.q,
        num_walks=args.num_walks,
        eval_set=eval_set,
        eval_steps=1,
        device=args.device,
    )

    trainer.train(
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=learning_rate,
    )


if __name__ == "__main__":
    # Entry point: parse CLI arguments, load the dataset, then either train a
    # node2vec model ("train") or time random walks ("time").
    args = parse_arguments()

    # Reject unknown tasks before load_graph, so a typo fails immediately
    # instead of after the dataset has been loaded.
    if args.task not in ("train", "time"):
        raise ValueError("Task type error!")

    graph, eval_set = load_graph(args.dataset)

    if args.task == "train":
        print("Perform training node2vec model")
        train_node2vec(graph, eval_set, args)
    else:
        print("Timing random walks")
        time_randomwalk(graph, args)