"docs/source/en/index.mdx" did not exist on "b8894f181dd10fb93e8ffb86ad3cfdda7f4a3380"
bench_nn_heterographconv.py 1.39 KB
Newer Older
1
import time
2

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
3
4
5
import dgl
import dgl.function as fn

6
import numpy as np
7
import torch
8
9
import torch.nn as nn
import torch.nn.functional as F
10
11
from dgl.nn.pytorch import HeteroGraphConv, SAGEConv

12
13
14
from .. import utils


15
16
17
@utils.benchmark("time")
@utils.parametrize("feat_dim", [4, 32, 256])
@utils.parametrize("num_relations", [5, 50, 200])
def track_time(feat_dim, num_relations):
    """Benchmark the forward pass of ``HeteroGraphConv`` over many relations.

    Builds a heterograph with two node types (``"n1"`` -> ``"n2"``) and
    ``num_relations`` edge types ``"e_0"``, ``"e_1"``, ...  Each edge type
    reuses the edge list of Cora, Pubmed or Citeseer in round-robin order.
    Every relation gets its own ``SAGEConv`` (mean aggregator, ReLU
    activation), and all of them are wrapped in one ``HeteroGraphConv``.

    Parameters
    ----------
    feat_dim : int
        Input and output feature size of every ``SAGEConv``.
    num_relations : int
        Number of edge types in the heterograph.

    Returns
    -------
    float
        ``utils.Timer().elapsed_secs`` for 50 timed forward calls divided
        by 50, i.e. the average time of one call.  Three untimed warm-up
        calls run first.
    """
    device = utils.get_bench_device()

    candidate_edges = [
        dgl.data.CoraGraphDataset(verbose=False)[0].edges(),
        dgl.data.PubmedGraphDataset(verbose=False)[0].edges(),
        dgl.data.CiteseerGraphDataset(verbose=False)[0].edges(),
    ]

    data_dict = {}
    nn_dict = {}
    for i in range(num_relations):
        etype = "e_{}".format(i)
        data_dict[("n1", etype, "n2")] = candidate_edges[
            i % len(candidate_edges)
        ]
        nn_dict[etype] = SAGEConv(
            feat_dim, feat_dim, "mean", activation=F.relu
        )

    # DGL requires the graph to live on the same device as the features and
    # modules it is used with, so move it next to them.
    graph = dgl.heterograph(data_dict).to(device)

    # HeteroGraphConv looks up input features by *node* type (the source and
    # destination type of each relation).  Features keyed by edge type are
    # never matched, so every relation would be skipped and nothing would
    # actually be computed.
    feat_dict = {
        ntype: torch.randn((graph.num_nodes(ntype), feat_dim), device=device)
        for ntype in graph.ntypes
    }

    conv = HeteroGraphConv(nn_dict).to(device)

    # warm-up (untimed)
    for _ in range(3):
        conv(graph, feat_dict)
    # timing
    with utils.Timer() as t:
        for _ in range(50):
            conv(graph, feat_dict)

    return t.elapsed_secs / 50