"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ff263947ad623800bfe13297ac308073f01e8dea"
Unverified Commit 2295c218 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Benchmark] Add benchmark for GraphConv and HeteroGraphConv (#2999)



* fix

* lint

* Revert "lint"

This reverts commit 263f913f02a9ece8491e9a9f812e50892da899ba.
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 789b03f4
import time
import dgl
import torch
import numpy as np
import dgl.function as fn
from dgl.nn.pytorch import SAGEConv
import torch.nn as nn
import torch.nn.functional as F
from .. import utils
@utils.benchmark('time')
@utils.parametrize('graph_name', ['pubmed','ogbn-arxiv'])
@utils.parametrize('feat_dim', [4, 32, 256])
@utils.parametrize('aggr_type', ['mean', 'gcn', 'pool'])
def track_time(graph_name, feat_dim, aggr_type):
device = utils.get_bench_device()
graph = utils.get_graph(graph_name).to(device)
feat = torch.randn((graph.num_nodes(), feat_dim), device=device)
model = SAGEConv(feat_dim, feat_dim, aggr_type, activation=F.relu, bias=False).to(device)
# dry run
for i in range(3):
model(graph, feat)
# timing
with utils.Timer() as t:
for i in range(50):
model(graph, feat)
return t.elapsed_secs / 50
import time
import dgl
import torch
import numpy as np
import dgl.function as fn
from dgl.nn.pytorch import SAGEConv, HeteroGraphConv
import torch.nn as nn
import torch.nn.functional as F
from .. import utils
@utils.benchmark('time')
@utils.parametrize('feat_dim', [4, 32, 256])
@utils.parametrize('num_relations', [5, 50, 200])
def track_time(feat_dim, num_relations):
device = utils.get_bench_device()
dd = {}
nn_dict = {}
candidate_edges = [dgl.data.CoraGraphDataset(verbose=False)[0].edges(), dgl.data.PubmedGraphDataset(verbose=False)[
0].edges(), dgl.data.CiteseerGraphDataset(verbose=False)[0].edges()]
for i in range(num_relations):
dd[('n1', 'e_{}'.format(i), 'n2')] = candidate_edges[i %
len(candidate_edges)]
nn_dict['e_{}'.format(i)] = SAGEConv(feat_dim, feat_dim, 'mean', activation=F.relu)
# dry run
feat_dict = {}
graph = dgl.heterograph(dd)
for i in range(num_relations):
etype = 'e_{}'.format(i)
feat_dict[etype] = torch.randn((graph[etype].num_nodes(), feat_dim), device=device)
conv = HeteroGraphConv(nn_dict).to(device)
# dry run
for i in range(3):
conv(graph, feat_dict)
# timing
with utils.Timer() as t:
for i in range(50):
conv(graph, feat_dict)
return t.elapsed_secs / 50
...@@ -64,10 +64,14 @@ def thread_wrapped_func(func): ...@@ -64,10 +64,14 @@ def thread_wrapped_func(func):
return decorated_function return decorated_function
def get_graph(name, format): def get_graph(name, format = None):
# global GRAPH_CACHE # global GRAPH_CACHE
# if name in GRAPH_CACHE: # if name in GRAPH_CACHE:
# return GRAPH_CACHE[name].to(format) # return GRAPH_CACHE[name].to(format)
if isinstance(format, str):
format = [format] # didn't specify format
if format is None:
format = ['csc', 'csr', 'coo']
g = None g = None
if name == 'cora': if name == 'cora':
g = dgl.data.CoraGraphDataset(verbose=False)[0] g = dgl.data.CoraGraphDataset(verbose=False)[0]
...@@ -79,7 +83,7 @@ def get_graph(name, format): ...@@ -79,7 +83,7 @@ def get_graph(name, format):
g_list, _ = dgl.load_graphs(bin_path) g_list, _ = dgl.load_graphs(bin_path)
g = g_list[0] g = g_list[0]
else: else:
g = get_livejournal().formats([format]) g = get_livejournal().formats(format)
dgl.save_graphs(bin_path, [g]) dgl.save_graphs(bin_path, [g])
elif name == "friendster": elif name == "friendster":
bin_path = "/tmp/dataset/friendster/friendster_{}.bin".format(format) bin_path = "/tmp/dataset/friendster/friendster_{}.bin".format(format)
...@@ -87,7 +91,7 @@ def get_graph(name, format): ...@@ -87,7 +91,7 @@ def get_graph(name, format):
g_list, _ = dgl.load_graphs(bin_path) g_list, _ = dgl.load_graphs(bin_path)
g = g_list[0] g = g_list[0]
else: else:
g = get_friendster().formats([format]) g = get_friendster().formats(format)
dgl.save_graphs(bin_path, [g]) dgl.save_graphs(bin_path, [g])
elif name == "reddit": elif name == "reddit":
bin_path = "/tmp/dataset/reddit/reddit_{}.bin".format(format) bin_path = "/tmp/dataset/reddit/reddit_{}.bin".format(format)
...@@ -102,7 +106,7 @@ def get_graph(name, format): ...@@ -102,7 +106,7 @@ def get_graph(name, format):
else: else:
raise Exception("Unknown dataset") raise Exception("Unknown dataset")
# GRAPH_CACHE[name] = g # GRAPH_CACHE[name] = g
g = g.formats([format]) g = g.formats(format)
return g return g
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment