Unverified Commit d78a3a4b authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4640)



* auto fix

* add more

* sort
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 23d09057
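The diff below is the mechanical output of Black over the benchmark tree. For reference, a minimal sketch of applying Black programmatically (the commit itself presumably ran the `black` CLI; `black.format_str` is Black's documented Python entry point, and the default line length is 88):

# Minimal sketch: format a source string with Black's Python API.
# Assumes the `black` package is installed; the snippet only parses, it
# does not execute the code being formatted.
import black

src = "x=torch.randn((graph.num_edges(),num_heads)).requires_grad_(True)"
formatted = black.format_str(src, mode=black.Mode())
print(formatted)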
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl.nn.pytorch import HeteroGraphConv, SAGEConv

from .. import utils


@utils.benchmark("time")
@utils.parametrize("feat_dim", [4, 32, 256])
@utils.parametrize("num_relations", [5, 50, 200])
def track_time(feat_dim, num_relations):
    device = utils.get_bench_device()
    dd = {}
    nn_dict = {}
    candidate_edges = [
        dgl.data.CoraGraphDataset(verbose=False)[0].edges(),
        dgl.data.PubmedGraphDataset(verbose=False)[0].edges(),
        dgl.data.CiteseerGraphDataset(verbose=False)[0].edges(),
    ]
    for i in range(num_relations):
        dd[("n1", "e_{}".format(i), "n2")] = candidate_edges[
            i % len(candidate_edges)
        ]
        nn_dict["e_{}".format(i)] = SAGEConv(
            feat_dim, feat_dim, "mean", activation=F.relu
        )

    # dry run
    feat_dict = {}
    graph = dgl.heterograph(dd)
    for i in range(num_relations):
        etype = "e_{}".format(i)
        feat_dict[etype] = torch.randn(
            (graph[etype].num_nodes(), feat_dim), device=device
        )
    conv = HeteroGraphConv(nn_dict).to(device)
...
import time

import numpy as np
import torch

import dgl
import dgl.function as fn

from .. import utils


@utils.benchmark("time")
@utils.parametrize("graph_name", ["livejournal", "reddit"])
@utils.parametrize("format", ["coo", "csc"])
@utils.parametrize("seed_nodes_num", [200, 5000, 20000])
def track_time(graph_name, format, seed_nodes_num):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph_name, format)
...
import time

import torch

import dgl

from .. import utils


def _random_walk(g, seeds, length):
    return dgl.sampling.random_walk(g, seeds, length=length)


def _node2vec(g, seeds, length):
    return dgl.sampling.node2vec_random_walk(g, seeds, 1, 1, length)


@utils.skip_if_gpu()
@utils.benchmark("time")
@utils.parametrize("graph_name", ["cora", "livejournal", "friendster"])
@utils.parametrize("num_seeds", [10, 100, 1000])
@utils.parametrize("length", [2, 5, 10, 20])
@utils.parametrize("algorithm", ["_random_walk", "_node2vec"])
def track_time(graph_name, num_seeds, length, algorithm):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph_name, "csr")
    seeds = torch.randint(0, graph.num_nodes(), (num_seeds,))
    print(graph_name, num_seeds, length)
    alg = globals()[algorithm]
...
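For context, `dgl.sampling.random_walk` returns a per-seed node trace (padded with -1 when a walk terminates early) together with the node-type trace; a toy sketch on a hand-built graph:

# Toy sketch of the API benchmarked above; the graph is illustrative.
import dgl
import torch

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))  # a 4-node directed cycle
seeds = torch.tensor([0, 2])
traces, types = dgl.sampling.random_walk(g, seeds, length=3)
print(traces.shape)  # (num_seeds, length + 1) -> torch.Size([2, 4])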
import time

import torch

import dgl

from .. import utils


@utils.benchmark("time")
@utils.parametrize("batch_size", [4, 256, 1024])
@utils.parametrize("feat_size", [16, 128, 512])
@utils.parametrize("readout_op", ["sum", "max", "min", "mean"])
@utils.parametrize("type", ["edge", "node"])
def track_time(batch_size, feat_size, readout_op, type):
    device = utils.get_bench_device()
    ds = dgl.data.QM7bDataset()
@@ -17,20 +19,20 @@ def track_time(batch_size, feat_size, readout_op, type):
    graphs = ds[0:batch_size][0]
    g = dgl.batch(graphs).to(device)

    if type == "node":
        g.ndata["h"] = torch.randn((g.num_nodes(), feat_size), device=device)
        for i in range(10):
            out = dgl.readout_nodes(g, "h", op=readout_op)
        with utils.Timer() as t:
            for i in range(50):
                out = dgl.readout_nodes(g, "h", op=readout_op)
    elif type == "edge":
        g.edata["h"] = torch.randn((g.num_edges(), feat_size), device=device)
        for i in range(10):
            out = dgl.readout_edges(g, "h", op=readout_op)
        with utils.Timer() as t:
            for i in range(50):
                out = dgl.readout_edges(g, "h", op=readout_op)
    else:
        raise Exception("Unknown type")
...
import time

import numpy as np
import torch

import dgl

from .. import utils


@utils.benchmark("time", timeout=1200)
@utils.parametrize_cpu("graph_name", ["cora", "livejournal", "friendster"])
@utils.parametrize_gpu("graph_name", ["cora", "livejournal"])
@utils.parametrize("format", ["coo", "csc", "csr"])
def track_time(graph_name, format):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph_name, format)
...
import time

import numpy as np
import torch

import dgl
import dgl.function as fn

from .. import utils


@utils.benchmark("time")
@utils.parametrize_cpu("graph_name", ["livejournal", "reddit"])
@utils.parametrize_gpu("graph_name", ["ogbn-arxiv", "reddit"])
@utils.parametrize("format", ["coo", "csc"])
@utils.parametrize("seed_nodes_num", [200, 5000, 20000])
@utils.parametrize("fanout", [5, 20, 40])
def track_time(graph_name, format, seed_nodes_num, fanout):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph_name, format)
    edge_dir = "in"
    seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)

    # dry run
    for i in range(3):
        dgl.sampling.sample_neighbors(
            graph, seed_nodes, fanout, edge_dir=edge_dir
        )

    # timing
    with utils.Timer() as t:
        for i in range(50):
            dgl.sampling.sample_neighbors(
                graph, seed_nodes, fanout, edge_dir=edge_dir
            )

    return t.elapsed_secs / 50
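The benchmark above times `dgl.sampling.sample_neighbors`; a minimal sketch of a single call on a toy graph:

# Toy sketch; the graph and seeds are illustrative.
import dgl
import torch

g = dgl.graph(([0, 0, 1, 2, 3], [1, 2, 2, 3, 0]))
# Sample up to 2 in-edges per seed node; the result is a subgraph over the
# same node set whose edges are the sampled ones.
sub = dgl.sampling.sample_neighbors(g, torch.tensor([2, 3]), 2, edge_dir="in")
print(sub.num_edges())  # at most 2 * num_seeds; fewer if a node's degree is lower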
import time

import numpy as np
import torch

import dgl

from .. import utils


@utils.skip_if_gpu()
@utils.benchmark("time", timeout=1200)
@utils.parametrize("graph_name", ["reddit", "ogbn-products"])
@utils.parametrize("num_seed_nodes", [32, 256, 1024, 2048])
@utils.parametrize("fanout", [5, 10, 20])
def track_time(graph_name, num_seed_nodes, fanout):
    device = utils.get_bench_device()
    data = utils.process_data(graph_name)
@@ -22,7 +24,8 @@ def track_time(graph_name, num_seed_nodes, fanout):
    subg_list = []
    for i in range(10):
        seed_nodes = np.random.randint(
            0, graph.num_nodes(), size=num_seed_nodes
        )
        subg = dgl.sampling.sample_neighbors(graph, seed_nodes, fanout)
        subg_list.append(subg)
...
import time

import numpy as np
import torch

import dgl
import dgl.function as fn

from .. import utils


@utils.benchmark("time", timeout=7200)
@utils.parametrize("graph_name", ["ogbn-arxiv", "pubmed"])
@utils.parametrize("format", ["coo"])  # only coo supports udf
@utils.parametrize("feat_size", [8, 32, 128, 512])
@utils.parametrize("reduce_type", ["u->e", "u+v"])
def track_time(graph_name, format, feat_size, reduce_type):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph_name, format)
    graph = graph.to(device)
    graph.ndata["h"] = torch.randn(
        (graph.num_nodes(), feat_size), device=device
    )
    reduce_udf_dict = {
        "u->e": lambda edges: {"x": edges.src["h"]},
        "u+v": lambda edges: {"x": edges.src["h"] + edges.dst["h"]},
    }

    # dry run
...
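Despite the `reduce_udf_dict` name, the lambdas above are edge UDFs (they take `edges` and compute a per-edge value). A minimal sketch of applying one with `apply_edges`, which is presumably what the elided timing loop does:

# Toy sketch; graph and feature sizes are illustrative.
import dgl
import torch

g = dgl.graph(([0, 1], [1, 0]))
g.ndata["h"] = torch.randn(2, 4)
# "u+v": per-edge sum of the source and destination features.
g.apply_edges(lambda edges: {"x": edges.src["h"] + edges.dst["h"]})
print(g.edata["x"].shape)  # (num_edges, 4)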
import time

import numpy as np
import torch

import dgl
import dgl.function as fn

from .. import utils


@utils.benchmark("time", timeout=600)
@utils.parametrize("feat_size", [32, 128, 512])
@utils.parametrize("num_relations", [5, 50, 500])
@utils.parametrize("multi_reduce_type", ["sum", "stack"])
def track_time(feat_size, num_relations, multi_reduce_type):
    device = utils.get_bench_device()
    dd = {}
    candidate_edges = [
        dgl.data.CoraGraphDataset(verbose=False)[0].edges(),
        dgl.data.PubmedGraphDataset(verbose=False)[0].edges(),
        dgl.data.CiteseerGraphDataset(verbose=False)[0].edges(),
    ]
    for i in range(num_relations):
        dd[("n1", "e_{}".format(i), "n2")] = candidate_edges[
            i % len(candidate_edges)
        ]
    graph = dgl.heterograph(dd)
    graph = graph.to(device)
    graph.nodes["n1"].data["h"] = torch.randn(
        (graph.num_nodes("n1"), feat_size), device=device
    )
    graph.nodes["n2"].data["h"] = torch.randn(
        (graph.num_nodes("n2"), feat_size), device=device
    )

    # dry run
    update_dict = {}
    for i in range(num_relations):
        update_dict["e_{}".format(i)] = (
            lambda edges: {"x": edges.src["h"]},
            lambda nodes: {"h_new": torch.sum(nodes.mailbox["x"], dim=1)},
        )
    graph.multi_update_all(update_dict, multi_reduce_type)

    # timing
    with utils.Timer() as t:
        for i in range(3):
            graph.multi_update_all(update_dict, multi_reduce_type)

    return t.elapsed_secs / 3
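`multi_update_all` takes one (message, reduce) pair per relation and then merges the per-relation results with the cross-type reducer ("sum" or "stack" here). A minimal sketch with two relations, using built-in functions instead of the UDFs above:

# Toy sketch; the heterograph and feature size are illustrative.
import dgl
import dgl.function as fn
import torch

g = dgl.heterograph({
    ("n1", "e_0", "n2"): ([0, 1], [0, 0]),
    ("n1", "e_1", "n2"): ([1, 1], [0, 1]),
})
g.nodes["n1"].data["h"] = torch.randn(g.num_nodes("n1"), 4)
g.multi_update_all(
    {
        "e_0": (fn.copy_u("h", "x"), fn.sum("x", "h_new")),
        "e_1": (fn.copy_u("h", "x"), fn.sum("x", "h_new")),
    },
    "sum",  # cross-type reducer: sum the per-relation results on "n2"
)
print(g.nodes["n2"].data["h_new"].shape)  # (num "n2" nodes, 4)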
import time

import numpy as np
import torch

import dgl
import dgl.function as fn

from .. import utils


@utils.benchmark("time", timeout=600)
@utils.parametrize("graph_name", ["pubmed", "ogbn-arxiv"])
@utils.parametrize("format", ["coo"])  # only coo supports udf
@utils.parametrize("feat_size", [8, 64, 512])
@utils.parametrize("msg_type", ["copy_u", "u_mul_e"])
@utils.parametrize("reduce_type", ["sum", "mean", "max"])
def track_time(graph_name, format, feat_size, msg_type, reduce_type):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph_name, format)
    graph = graph.to(device)
    graph.ndata["h"] = torch.randn(
        (graph.num_nodes(), feat_size), device=device
    )
    graph.edata["e"] = torch.randn((graph.num_edges(), 1), device=device)

    msg_udf_dict = {
        "copy_u": lambda edges: {"x": edges.src["h"]},
        "u_mul_e": lambda edges: {"x": edges.src["h"] * edges.data["e"]},
    }

    reduct_udf_dict = {
        "sum": lambda nodes: {"h_new": torch.sum(nodes.mailbox["x"], dim=1)},
        "mean": lambda nodes: {"h_new": torch.mean(nodes.mailbox["x"], dim=1)},
        "max": lambda nodes: {"h_new": torch.max(nodes.mailbox["x"], dim=1)[0]},
    }

    # dry run
@@ -39,7 +41,8 @@ def track_time(graph_name, format, feat_size, msg_type, reduce_type):
    # timing
    with utils.Timer() as t:
        for i in range(3):
            graph.update_all(
                msg_udf_dict[msg_type], reduct_udf_dict[reduce_type]
            )

    return t.elapsed_secs / 3
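The UDF pairs above have built-in equivalents in `dgl.function` that DGL can fuse into a single kernel, which is the faster path such UDF benchmarks are usually compared against; a minimal sketch:

# Toy sketch; graph and feature sizes are illustrative.
import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata["h"] = torch.randn(3, 8)
g.edata["e"] = torch.randn(3, 1)

# Built-in equivalent of msg_type="copy_u", reduce_type="sum".
g.update_all(fn.copy_u("h", "x"), fn.sum("x", "h_new"))

# Built-in equivalent of msg_type="u_mul_e": source feature times edge scalar.
g.update_all(fn.u_mul_e("h", "e", "x"), fn.sum("x", "h_new"))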
import time

import torch

import dgl

from .. import utils


@utils.benchmark("time")
@utils.parametrize("batch_size", [4, 32, 256, 1024])
def track_time(batch_size):
    device = utils.get_bench_device()
    ds = dgl.data.QM7bDataset()
...
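`dgl.batch` merges a list of graphs into one disjoint-union graph so per-graph work (such as the readouts benchmarked earlier) runs in a single pass; a toy sketch:

# Toy sketch; the two graphs are illustrative.
import dgl

g1 = dgl.graph(([0, 1], [1, 0]))          # 2 nodes
g2 = dgl.graph(([0, 1, 2], [1, 2, 0]))    # 3 nodes
bg = dgl.batch([g1, g2])
print(bg.batch_size, bg.num_nodes())       # 2, 5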
import time

import torch

import dgl

from .. import utils


# The benchmarks for ops edge_softmax
@utils.benchmark("time", timeout=600)
@utils.parametrize("graph", ["ogbn-arxiv", "reddit", "cora", "pubmed"])
@utils.parametrize("num_heads", [1, 4, 8])
def track_time(graph, num_heads):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph).to(device)
    score = (
        torch.randn((graph.num_edges(), num_heads))
        .requires_grad_(True)
        .float()
        .to(device)
    )

    # dry run
    for i in range(3):
@@ -22,4 +30,4 @@ def track_time(graph, num_heads):
        for i in range(100):
            y = dgl.ops.edge_softmax(graph, score)

    return t.elapsed_secs / 100
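`dgl.ops.edge_softmax` normalizes edge scores over the incoming edges of each destination node (the attention normalization used in GAT); a toy sketch:

# Toy sketch; the graph is illustrative.
import dgl
import torch

g = dgl.graph(([0, 1, 2], [2, 2, 2]))  # three edges, all into node 2
score = torch.randn(g.num_edges(), 1)
a = dgl.ops.edge_softmax(g, score)
print(a.sum())  # edges sharing a destination sum to 1 -> ~1.0 here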
import time

import torch

import dgl

from .. import utils


def calc_gflops(graph, feat_size, num_heads, time):
    return round(
        2 * graph.num_edges() * feat_size / 1000000000 / time, 2
    )  # count both mul and add


# The benchmarks include broadcasting cases.
# Given feat_size = D, num_heads = H, the node feature shape will be (H, D // H)
@@ -14,17 +20,19 @@ def calc_gflops(graph, feat_size, num_heads, time):
# matter how many heads are there.
# If num_heads = 0, it falls back to the normal element-wise operation without
# broadcasting.
@utils.benchmark("flops", timeout=600)
@utils.parametrize("graph", ["ogbn-arxiv", "reddit", "ogbn-proteins"])
@utils.parametrize("feat_size", [4, 32, 256])
@utils.parametrize("num_heads", [0, 1, 4])
def track_flops(graph, feat_size, num_heads):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph, format="coo").to(device)
    if num_heads == 0:
        x = torch.randn(graph.num_nodes(), feat_size, device=device)
    else:
        x = torch.randn(
            graph.num_nodes(), num_heads, feat_size // num_heads, device=device
        )

    # dry run
    for i in range(3):
...
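The broadcasting described in the comments shows up directly in operator shapes. A sketch with `dgl.ops.u_mul_e`, chosen here purely for illustration since the benchmarked op is elided: per-head node features of shape (H, D // H) broadcast against per-edge weights of shape (H, 1):

# Toy sketch; the graph, H, and D are illustrative.
import dgl
import torch

g = dgl.graph(([0, 1], [1, 0]))
H, D = 4, 32
x = torch.randn(g.num_nodes(), H, D // H)  # (N, H, D // H)
w = torch.randn(g.num_edges(), H, 1)       # (E, H, 1), broadcast over last dim
y = dgl.ops.u_mul_e(g, x, w)               # per-edge product: (E, H, D // H)
print(y.shape)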
import time

import torch

import dgl

from .. import utils


def calc_gflops(graph, feat_size, time):
    return round(graph.num_edges() * feat_size / 1000000000 / time, 2)


@utils.benchmark("flops", timeout=600)
@utils.parametrize("graph", ["ogbn-arxiv", "reddit", "ogbn-proteins"])
@utils.parametrize("feat_size", [4, 32, 256])
@utils.parametrize("reducer", ["sum", "max"])
def track_flops(graph, feat_size, reducer):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph, format="csc").to(device)
    x = torch.randn(graph.num_nodes(), feat_size, device=device)
    if reducer == "sum":
        op = dgl.ops.copy_u_sum
    elif reducer == "max":
        op = dgl.ops.copy_u_max
    else:
        raise ValueError("Invalid reducer", reducer)

    # dry run
    for i in range(3):
...
import time

import torch

import dgl

from .. import utils


def calc_gflops(graph, feat_size, num_heads, time):
    return round(
        2 * graph.num_edges() * feat_size / 1000000000 / time, 2
    )  # count both mul and add


# The benchmarks include broadcasting cases.
# Given feat_size = D, num_heads = H, the node feature shape will be (H, D // H)
@@ -14,18 +20,20 @@ def calc_gflops(graph, feat_size, num_heads, time):
# matter how many heads are there.
# If num_heads = 0, it falls back to the normal element-wise operation without
# broadcasting.
@utils.benchmark("flops", timeout=600)
@utils.parametrize("graph", ["ogbn-arxiv", "reddit", "ogbn-proteins"])
@utils.parametrize("feat_size", [4, 32, 256])
@utils.parametrize("num_heads", [0, 1, 4])
def track_flops(graph, feat_size, num_heads):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph, format="csc").to(device)
    if num_heads == 0:
        x = torch.randn(graph.num_nodes(), feat_size, device=device)
        w = torch.randn(graph.num_edges(), feat_size, device=device)
    else:
        x = torch.randn(
            graph.num_nodes(), num_heads, feat_size // num_heads, device=device
        )
        w = torch.randn(graph.num_edges(), num_heads, 1, device=device)

    # dry run
...
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.nn.pytorch import GATConv

from .. import utils


class GAT(nn.Module):
    def __init__(
        self,
        num_layers,
        in_dim,
        num_hidden,
        num_classes,
        heads,
        activation,
        feat_drop,
        attn_drop,
        negative_slope,
        residual,
    ):
        super(GAT, self).__init__()
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(
            GATConv(
                in_dim,
                num_hidden,
                heads[0],
                feat_drop,
                attn_drop,
                negative_slope,
                False,
                self.activation,
            )
        )
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gat_layers.append(
                GATConv(
                    num_hidden * heads[l - 1],
                    num_hidden,
                    heads[l],
                    feat_drop,
                    attn_drop,
                    negative_slope,
                    residual,
                    self.activation,
                )
            )
        # output projection
        self.gat_layers.append(
            GATConv(
                num_hidden * heads[-2],
                num_classes,
                heads[-1],
                feat_drop,
                attn_drop,
                negative_slope,
                residual,
                None,
            )
        )

    def forward(self, g, inputs):
        h = inputs
@@ -45,6 +76,7 @@ class GAT(nn.Module):
        logits = self.gat_layers[-1](g, h).mean(1)
        return logits


def evaluate(model, g, features, labels, mask):
    model.eval()
    with torch.no_grad():
@@ -55,19 +87,20 @@ def evaluate(model, g, features, labels, mask):
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels) * 100


@utils.benchmark("acc")
@utils.parametrize("data", ["cora", "pubmed"])
def track_acc(data):
    data = utils.process_data(data)
    device = utils.get_bench_device()

    g = data[0].to(device)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = features.shape[1]
    n_classes = data.num_classes
@@ -76,17 +109,14 @@ def track_acc(data):
    g = dgl.add_self_loop(g)

    # create model
    model = GAT(1, in_feats, 8, n_classes, [8, 1], F.elu, 0.6, 0.6, 0.2, False)
    loss_fcn = torch.nn.CrossEntropyLoss()

    model = model.to(device)
    model.train()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    for epoch in range(200):
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
...
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.nn.pytorch import GraphConv

from .. import utils


class GCN(nn.Module):
    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        super(GCN, self).__init__()
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(
                GraphConv(n_hidden, n_hidden, activation=activation)
            )
        # output layer
        self.layers.append(GraphConv(n_hidden, n_classes))
        self.dropout = nn.Dropout(p=dropout)
@@ -33,6 +33,7 @@ class GCN(nn.Module):
            h = layer(g, h)
        return h


def evaluate(model, g, features, labels, mask):
    model.eval()
    with torch.no_grad():
@@ -43,19 +44,20 @@ def evaluate(model, g, features, labels, mask):
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels) * 100


@utils.benchmark("acc")
@utils.parametrize("data", ["cora", "pubmed"])
def track_acc(data):
    data = utils.process_data(data)
    device = utils.get_bench_device()

    g = data[0].to(device).int()

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = features.shape[1]
    n_classes = data.num_classes
@@ -67,7 +69,7 @@ def track_acc(data):
    degs = g.in_degrees().float()
    norm = torch.pow(degs, -0.5)
    norm[torch.isinf(norm)] = 0
    g.ndata["norm"] = norm.unsqueeze(1)

    # create GCN model
    model = GCN(in_feats, 16, n_classes, 1, F.relu, 0.5)
@@ -77,9 +79,7 @@ def track_acc(data):
    model.train()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    for epoch in range(200):
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
...
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl

from .. import utils


class GraphConv(nn.Module):
    def __init__(self, in_dim, out_dim, activation=None):
        super(GraphConv, self).__init__()
@@ -18,20 +20,42 @@ class GraphConv(nn.Module):
    def forward(self, graph, feat):
        with graph.local_scope():
            graph.ndata["ci"] = torch.pow(
                graph.out_degrees().float().clamp(min=1), -0.5
            )
            graph.ndata["cj"] = torch.pow(
                graph.in_degrees().float().clamp(min=1), -0.5
            )
            graph.ndata["h"] = feat
            graph.update_all(self.mfunc, self.rfunc)
            h = graph.ndata["h"]
            h = torch.matmul(h, self.weight) + self.bias
            if self.activation is not None:
                h = self.activation(h)
            return h

    def mfunc(self, edges):
        return {"m": edges.src["h"], "ci": edges.src["ci"]}

    def rfunc(self, nodes):
        ci = nodes.mailbox["ci"].unsqueeze(2)
        newh = (nodes.mailbox["m"] * ci).sum(1) * nodes.data["cj"].unsqueeze(1)
        return {"h": newh}


class GCN(nn.Module):
    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        super(GCN, self).__init__()
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(
                GraphConv(n_hidden, n_hidden, activation=activation)
            )
        # output layer
        self.layers.append(GraphConv(n_hidden, n_classes))
        self.dropout = nn.Dropout(p=dropout)
@@ -63,6 +68,7 @@ class GCN(nn.Module):
            h = layer(g, h)
        return h


def evaluate(model, g, features, labels, mask):
    model.eval()
    with torch.no_grad():
@@ -73,19 +79,20 @@ def evaluate(model, g, features, labels, mask):
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels) * 100


@utils.benchmark("acc", timeout=300)
@utils.parametrize("data", ["cora", "pubmed"])
def track_acc(data):
    data = utils.process_data(data)
    device = utils.get_bench_device()

    g = data[0].to(device).int()

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = features.shape[1]
    n_classes = data.num_classes
@@ -97,7 +104,7 @@ def track_acc(data):
    degs = g.in_degrees().float()
    norm = torch.pow(degs, -0.5)
    norm[torch.isinf(norm)] = 0
    g.ndata["norm"] = norm.unsqueeze(1)

    # create GCN model
    model = GCN(in_feats, 16, n_classes, 1, F.relu, 0.5)
@@ -107,9 +114,7 @@ def track_acc(data):
    model.train()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    for epoch in range(200):
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
...
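The hand-written mfunc/rfunc pair above implements symmetrically normalized aggregation, newh_v = cj_v * sum over in-edges (ci_u * h_u). A hedged sketch of the same aggregation with built-in functions (pre-scale by the source term, copy-sum, then post-scale by the destination term):

# Toy sketch; graph and feature sizes are illustrative.
import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
h = torch.randn(g.num_nodes(), 8)
ci = torch.pow(g.out_degrees().float().clamp(min=1), -0.5)
cj = torch.pow(g.in_degrees().float().clamp(min=1), -0.5)

# Scale each source feature by its ci before sending, sum at the
# destination, then scale by the destination's cj, matching the
# mfunc/rfunc pair in the GraphConv above (without weight and bias).
g.ndata["h_scaled"] = h * ci.unsqueeze(1)
g.update_all(fn.copy_u("h_scaled", "m"), fn.sum("m", "agg"))
newh = g.ndata["agg"] * cj.unsqueeze(1)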
@@ -2,41 +2,51 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy

from .. import rgcn, utils


@utils.benchmark("acc", timeout=1200)
@utils.parametrize("dataset", ["aifb", "mutag"])
@utils.parametrize("ns_mode", [False])
def track_acc(dataset, ns_mode):
    (
        g,
        num_rels,
        num_classes,
        labels,
        train_idx,
        test_idx,
        target_idx,
    ) = rgcn.load_data(dataset, get_norm=True)

    num_hidden = 16
    if dataset == "aifb":
        num_bases = -1
        l2norm = 0.0
    elif dataset == "mutag":
        num_bases = 30
        l2norm = 5e-4
    elif dataset == "am":
        num_bases = 40
        l2norm = 5e-4
    else:
        raise ValueError()

    model = rgcn.RGCN(
        g.num_nodes(),
        num_hidden,
        num_classes,
        num_rels,
        num_bases=num_bases,
        ns_mode=ns_mode,
    )
    device = utils.get_bench_device()
    labels = labels.to(device)
    model = model.to(device)
    g = g.int().to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=1e-2, weight_decay=l2norm
    )

    model.train()
    for epoch in range(30):
@@ -51,7 +61,6 @@ def track_acc(dataset, ns_mode):
    with torch.no_grad():
        logits = model(g)
        logits = logits[target_idx]
    test_acc = accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item()
    return test_acc
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.nn.pytorch import SAGEConv

from .. import utils


class GraphSAGE(nn.Module):
    def __init__(
        self,
        in_feats,
        n_hidden,
        n_classes,
        n_layers,
        activation,
        dropout,
        aggregator_type,
    ):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)
@@ -26,7 +30,9 @@ class GraphSAGE(nn.Module):
        for i in range(n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type))
        # output layer
        self.layers.append(
            SAGEConv(n_hidden, n_classes, aggregator_type)
        )  # activation None

    def forward(self, graph, inputs):
        h = self.dropout(inputs)
@@ -37,6 +43,7 @@ class GraphSAGE(nn.Module):
            h = self.dropout(h)
        return h


def evaluate(model, g, features, labels, mask):
    model.eval()
    with torch.no_grad():
@@ -47,19 +54,20 @@ def evaluate(model, g, features, labels, mask):
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels) * 100


@utils.benchmark("acc")
@utils.parametrize("data", ["cora", "pubmed"])
def track_acc(data):
    data = utils.process_data(data)
    device = utils.get_bench_device()

    g = data[0].to(device)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = features.shape[1]
    n_classes = data.num_classes
@@ -68,16 +76,14 @@ def track_acc(data):
    g = dgl.add_self_loop(g)

    # create model
    model = GraphSAGE(in_feats, 16, n_classes, 1, F.relu, 0.5, "gcn")
    loss_fcn = torch.nn.CrossEntropyLoss()

    model = model.to(device)
    model.train()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    for epoch in range(200):
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
...