Unverified commit 9836f78e authored by Hongzhi (Steve), Chen; committed by GitHub
Browse files

autoformat (#5322)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 704bcaf6
import time import time
import dgl
import dgl.function as fn
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.nn.pytorch import SAGEConv from dgl.nn.pytorch import SAGEConv
from .. import utils from .. import utils
......
import time import time
import dgl
import dgl.function as fn
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.nn.pytorch import HeteroGraphConv, SAGEConv from dgl.nn.pytorch import HeteroGraphConv, SAGEConv
from .. import utils from .. import utils
......
import time import time
import numpy as np
import torch
import dgl import dgl
import dgl.function as fn import dgl.function as fn
import numpy as np
import torch
from .. import utils from .. import utils
......
import time import time
import torch
import dgl import dgl
import torch
from .. import utils from .. import utils
......
import time import time
import torch
import dgl import dgl
import torch
from .. import utils from .. import utils
......
import time import time
import dgl
import numpy as np import numpy as np
import torch import torch
import dgl
from .. import utils from .. import utils
......
import time import time
import numpy as np
import torch
import dgl import dgl
import dgl.function as fn import dgl.function as fn
import numpy as np
import torch
from .. import utils from .. import utils
......
import time import time
import dgl
import numpy as np import numpy as np
import torch import torch
import dgl
from .. import utils from .. import utils
......
import time import time
import numpy as np
import torch
import dgl import dgl
import dgl.function as fn import dgl.function as fn
import numpy as np
import torch
from .. import utils from .. import utils
......
import time import time
import numpy as np
import torch
import dgl import dgl
import dgl.function as fn import dgl.function as fn
import numpy as np
import torch
from .. import utils from .. import utils
......
import time import time
import numpy as np
import torch
import dgl import dgl
import dgl.function as fn import dgl.function as fn
import numpy as np
import torch
from .. import utils from .. import utils
......
import time import time
import torch
import dgl import dgl
import torch
from .. import utils from .. import utils
......
import time import time
import torch
import dgl import dgl
import torch
from .. import utils from .. import utils
......
import time import time
import torch
import dgl import dgl
import torch
from .. import utils from .. import utils
......
import time import time
import torch
import dgl import dgl
import torch
from .. import utils from .. import utils
......
import time import time
import torch
import dgl import dgl
import torch
from .. import utils from .. import utils
......
import dgl
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import GATConv from dgl.nn.pytorch import GATConv
from .. import utils from .. import utils
......
import dgl
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import GraphConv from dgl.nn.pytorch import GraphConv
from .. import utils from .. import utils
......
import dgl
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl
from .. import utils from .. import utils
......
import dgl
import itertools import itertools
import time
import dgl
import dgl.nn.pytorch as dglnn
import torch as th import torch as th
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import dgl.nn.pytorch as dglnn
from dgl.nn import RelGraphConv from dgl.nn import RelGraphConv
import time from torch.utils.data import DataLoader
from .. import utils from .. import utils
class EntityClassify(nn.Module): class EntityClassify(nn.Module):
""" Entity classification class for RGCN """Entity classification class for RGCN
Parameters Parameters
---------- ----------
device : int device : int
...@@ -35,17 +37,20 @@ class EntityClassify(nn.Module): ...@@ -35,17 +37,20 @@ class EntityClassify(nn.Module):
use_self_loop : bool use_self_loop : bool
Use self loop if True, default False. Use self loop if True, default False.
""" """
def __init__(self,
device, def __init__(
num_nodes, self,
h_dim, device,
out_dim, num_nodes,
num_rels, h_dim,
num_bases=None, out_dim,
num_hidden_layers=1, num_rels,
dropout=0, num_bases=None,
use_self_loop=False, num_hidden_layers=1,
layer_norm=False): dropout=0,
use_self_loop=False,
layer_norm=False,
):
super(EntityClassify, self).__init__() super(EntityClassify, self).__init__()
self.device = device self.device = device
self.num_nodes = num_nodes self.num_nodes = num_nodes
...@@ -60,22 +65,47 @@ class EntityClassify(nn.Module): ...@@ -60,22 +65,47 @@ class EntityClassify(nn.Module):
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
# i2h # i2h
self.layers.append(RelGraphConv( self.layers.append(
self.h_dim, self.h_dim, self.num_rels, "basis", RelGraphConv(
self.num_bases, activation=F.relu, self_loop=self.use_self_loop, self.h_dim,
dropout=self.dropout, layer_norm = layer_norm)) self.h_dim,
self.num_rels,
"basis",
self.num_bases,
activation=F.relu,
self_loop=self.use_self_loop,
dropout=self.dropout,
layer_norm=layer_norm,
)
)
# h2h # h2h
for idx in range(self.num_hidden_layers): for idx in range(self.num_hidden_layers):
self.layers.append(RelGraphConv( self.layers.append(
self.h_dim, self.h_dim, self.num_rels, "basis", RelGraphConv(
self.num_bases, activation=F.relu, self_loop=self.use_self_loop, self.h_dim,
dropout=self.dropout, layer_norm = layer_norm)) self.h_dim,
self.num_rels,
"basis",
self.num_bases,
activation=F.relu,
self_loop=self.use_self_loop,
dropout=self.dropout,
layer_norm=layer_norm,
)
)
# h2o # h2o
self.layers.append(RelGraphConv( self.layers.append(
self.h_dim, self.out_dim, self.num_rels, "basis", RelGraphConv(
self.num_bases, activation=None, self.h_dim,
self_loop=self.use_self_loop, self.out_dim,
layer_norm = layer_norm)) self.num_rels,
"basis",
self.num_bases,
activation=None,
self_loop=self.use_self_loop,
layer_norm=layer_norm,
)
)
def forward(self, blocks, feats, norm=None): def forward(self, blocks, feats, norm=None):
if blocks is None: if blocks is None:
...@@ -84,9 +114,10 @@ class EntityClassify(nn.Module): ...@@ -84,9 +114,10 @@ class EntityClassify(nn.Module):
h = feats h = feats
for layer, block in zip(self.layers, blocks): for layer, block in zip(self.layers, blocks):
block = block.to(self.device) block = block.to(self.device)
h = layer(block, h, block.edata['etype'], block.edata['norm']) h = layer(block, h, block.edata["etype"], block.edata["norm"])
return h return h
class RelGraphEmbedLayer(nn.Module): class RelGraphEmbedLayer(nn.Module):
r"""Embedding layer for featureless heterograph. r"""Embedding layer for featureless heterograph.
Parameters Parameters
...@@ -107,15 +138,18 @@ class RelGraphEmbedLayer(nn.Module): ...@@ -107,15 +138,18 @@ class RelGraphEmbedLayer(nn.Module):
embed_name : str, optional embed_name : str, optional
Embed name Embed name
""" """
def __init__(self,
device, def __init__(
num_nodes, self,
node_tids, device,
num_of_ntype, num_nodes,
input_size, node_tids,
embed_size, num_of_ntype,
sparse_emb=False, input_size,
embed_name='embed'): embed_size,
sparse_emb=False,
embed_name="embed",
):
super(RelGraphEmbedLayer, self).__init__() super(RelGraphEmbedLayer, self).__init__()
self.device = device self.device = device
self.embed_size = embed_size self.embed_size = embed_size
...@@ -135,7 +169,9 @@ class RelGraphEmbedLayer(nn.Module): ...@@ -135,7 +169,9 @@ class RelGraphEmbedLayer(nn.Module):
nn.init.xavier_uniform_(embed) nn.init.xavier_uniform_(embed)
self.embeds[str(ntype)] = embed self.embeds[str(ntype)] = embed
self.node_embeds = th.nn.Embedding(node_tids.shape[0], self.embed_size, sparse=self.sparse_emb) self.node_embeds = th.nn.Embedding(
node_tids.shape[0], self.embed_size, sparse=self.sparse_emb
)
nn.init.uniform_(self.node_embeds.weight, -1.0, 1.0) nn.init.uniform_(self.node_embeds.weight, -1.0, 1.0)
def forward(self, node_ids, node_tids, type_ids, features): def forward(self, node_ids, node_tids, type_ids, features):
...@@ -157,57 +193,65 @@ class RelGraphEmbedLayer(nn.Module): ...@@ -157,57 +193,65 @@ class RelGraphEmbedLayer(nn.Module):
embeddings as the input of the next layer embeddings as the input of the next layer
""" """
tsd_ids = node_ids.to(self.node_embeds.weight.device) tsd_ids = node_ids.to(self.node_embeds.weight.device)
embeds = th.empty(node_ids.shape[0], self.embed_size, device=self.device) embeds = th.empty(
node_ids.shape[0], self.embed_size, device=self.device
)
for ntype in range(self.num_of_ntype): for ntype in range(self.num_of_ntype):
if features[ntype] is not None: if features[ntype] is not None:
loc = node_tids == ntype loc = node_tids == ntype
embeds[loc] = features[ntype][type_ids[loc]].to(self.device) @ self.embeds[str(ntype)].to(self.device) embeds[loc] = features[ntype][type_ids[loc]].to(
self.device
) @ self.embeds[str(ntype)].to(self.device)
else: else:
loc = node_tids == ntype loc = node_tids == ntype
embeds[loc] = self.node_embeds(tsd_ids[loc]).to(self.device) embeds[loc] = self.node_embeds(tsd_ids[loc]).to(self.device)
return embeds return embeds
def evaluate(model, embed_layer, eval_loader, node_feats): def evaluate(model, embed_layer, eval_loader, node_feats):
model.eval() model.eval()
embed_layer.eval() embed_layer.eval()
eval_logits = [] eval_logits = []
eval_seeds = [] eval_seeds = []
with th.no_grad(): with th.no_grad():
for sample_data in eval_loader: for sample_data in eval_loader:
th.cuda.empty_cache() th.cuda.empty_cache()
_, _, blocks = sample_data _, _, blocks = sample_data
feats = embed_layer(blocks[0].srcdata[dgl.NID], feats = embed_layer(
blocks[0].srcdata[dgl.NTYPE], blocks[0].srcdata[dgl.NID],
blocks[0].srcdata['type_id'], blocks[0].srcdata[dgl.NTYPE],
node_feats) blocks[0].srcdata["type_id"],
node_feats,
)
logits = model(blocks, feats) logits = model(blocks, feats)
eval_logits.append(logits.cpu().detach()) eval_logits.append(logits.cpu().detach())
eval_seeds.append(blocks[-1].dstdata['type_id'].cpu().detach()) eval_seeds.append(blocks[-1].dstdata["type_id"].cpu().detach())
eval_logits = th.cat(eval_logits) eval_logits = th.cat(eval_logits)
eval_seeds = th.cat(eval_seeds) eval_seeds = th.cat(eval_seeds)
return eval_logits, eval_seeds return eval_logits, eval_seeds
@utils.benchmark('acc', timeout=3600) # ogbn-mag takes ~1 hour to train
@utils.parametrize('data', ['am', 'ogbn-mag']) @utils.benchmark("acc", timeout=3600) # ogbn-mag takes ~1 hour to train
@utils.parametrize("data", ["am", "ogbn-mag"])
def track_acc(data): def track_acc(data):
dataset = utils.process_data(data) dataset = utils.process_data(data)
device = utils.get_bench_device() device = utils.get_bench_device()
if data == 'am': if data == "am":
n_bases = 40 n_bases = 40
l2norm = 5e-4 l2norm = 5e-4
n_epochs = 20 n_epochs = 20
elif data == 'ogbn-mag': elif data == "ogbn-mag":
n_bases = 2 n_bases = 2
l2norm = 0 l2norm = 0
n_epochs = 20 n_epochs = 20
else: else:
raise ValueError() raise ValueError()
fanouts = [25,15] fanouts = [25, 15]
n_layers = 2 n_layers = 2
batch_size = 1024 batch_size = 1024
n_hidden = 64 n_hidden = 64
...@@ -219,20 +263,20 @@ def track_acc(data): ...@@ -219,20 +263,20 @@ def track_acc(data):
hg = dataset[0] hg = dataset[0]
category = dataset.predict_category category = dataset.predict_category
num_classes = dataset.num_classes num_classes = dataset.num_classes
train_mask = hg.nodes[category].data.pop('train_mask') train_mask = hg.nodes[category].data.pop("train_mask")
train_idx = th.nonzero(train_mask, as_tuple=False).squeeze() train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
test_mask = hg.nodes[category].data.pop('test_mask') test_mask = hg.nodes[category].data.pop("test_mask")
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze() test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
labels = hg.nodes[category].data.pop('labels').to(device) labels = hg.nodes[category].data.pop("labels").to(device)
num_of_ntype = len(hg.ntypes) num_of_ntype = len(hg.ntypes)
num_rels = len(hg.canonical_etypes) num_rels = len(hg.canonical_etypes)
node_feats = [] node_feats = []
for ntype in hg.ntypes: for ntype in hg.ntypes:
if len(hg.nodes[ntype].data) == 0 or 'feat' not in hg.nodes[ntype].data: if len(hg.nodes[ntype].data) == 0 or "feat" not in hg.nodes[ntype].data:
node_feats.append(None) node_feats.append(None)
else: else:
feat = hg.nodes[ntype].data.pop('feat') feat = hg.nodes[ntype].data.pop("feat")
node_feats.append(feat.share_memory_()) node_feats.append(feat.share_memory_())
# get target category id # get target category id
...@@ -241,25 +285,27 @@ def track_acc(data): ...@@ -241,25 +285,27 @@ def track_acc(data):
if ntype == category: if ntype == category:
category_id = i category_id = i
g = dgl.to_homogeneous(hg) g = dgl.to_homogeneous(hg)
u, v, eid = g.all_edges(form='all') u, v, eid = g.all_edges(form="all")
# global norm # global norm
_, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True) _, inverse_index, count = th.unique(
v, return_inverse=True, return_counts=True
)
degrees = count[inverse_index] degrees = count[inverse_index]
norm = th.ones(eid.shape[0]) / degrees norm = th.ones(eid.shape[0]) / degrees
norm = norm.unsqueeze(1) norm = norm.unsqueeze(1)
g.edata['norm'] = norm g.edata["norm"] = norm
g.edata['etype'] = g.edata[dgl.ETYPE] g.edata["etype"] = g.edata[dgl.ETYPE]
g.ndata['type_id'] = g.ndata[dgl.NID] g.ndata["type_id"] = g.ndata[dgl.NID]
g.ndata['ntype'] = g.ndata[dgl.NTYPE] g.ndata["ntype"] = g.ndata[dgl.NTYPE]
node_ids = th.arange(g.number_of_nodes()) node_ids = th.arange(g.number_of_nodes())
# find out the target node ids # find out the target node ids
node_tids = g.ndata[dgl.NTYPE] node_tids = g.ndata[dgl.NTYPE]
loc = (node_tids == category_id) loc = node_tids == category_id
target_nids = node_ids[loc] target_nids = node_ids[loc]
g = g.formats('csc') g = g.formats("csc")
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts) sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
train_loader = dgl.dataloading.DataLoader( train_loader = dgl.dataloading.DataLoader(
g, g,
...@@ -268,7 +314,8 @@ def track_acc(data): ...@@ -268,7 +314,8 @@ def track_acc(data):
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
num_workers=num_workers) num_workers=num_workers,
)
test_loader = dgl.dataloading.DataLoader( test_loader = dgl.dataloading.DataLoader(
g, g,
target_nids[test_idx], target_nids[test_idx],
...@@ -276,37 +323,46 @@ def track_acc(data): ...@@ -276,37 +323,46 @@ def track_acc(data):
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
num_workers=num_workers) num_workers=num_workers,
)
# node features # node features
# None for one-hot feature, if not none, it should be the feature tensor. # None for one-hot feature, if not none, it should be the feature tensor.
embed_layer = RelGraphEmbedLayer(device, embed_layer = RelGraphEmbedLayer(
g.number_of_nodes(), device,
node_tids, g.number_of_nodes(),
num_of_ntype, node_tids,
node_feats, num_of_ntype,
n_hidden, node_feats,
sparse_emb=True) n_hidden,
sparse_emb=True,
)
# create model # create model
# all model params are in device. # all model params are in device.
model = EntityClassify(device, model = EntityClassify(
g.number_of_nodes(), device,
n_hidden, g.number_of_nodes(),
num_classes, n_hidden,
num_rels, num_classes,
num_bases=n_bases, num_rels,
num_hidden_layers=n_layers - 2, num_bases=n_bases,
dropout=dropout, num_hidden_layers=n_layers - 2,
use_self_loop=use_self_loop, dropout=dropout,
layer_norm=False) use_self_loop=use_self_loop,
layer_norm=False,
)
embed_layer = embed_layer.to(device) embed_layer = embed_layer.to(device)
model = model.to(device) model = model.to(device)
all_params = itertools.chain(model.parameters(), embed_layer.embeds.parameters()) all_params = itertools.chain(
model.parameters(), embed_layer.embeds.parameters()
)
optimizer = th.optim.Adam(all_params, lr=lr, weight_decay=l2norm) optimizer = th.optim.Adam(all_params, lr=lr, weight_decay=l2norm)
emb_optimizer = th.optim.SparseAdam(list(embed_layer.node_embeds.parameters()), lr=lr) emb_optimizer = th.optim.SparseAdam(
list(embed_layer.node_embeds.parameters()), lr=lr
)
print("start training...") print("start training...")
for epoch in range(n_epochs): for epoch in range(n_epochs):
...@@ -315,12 +371,14 @@ def track_acc(data): ...@@ -315,12 +371,14 @@ def track_acc(data):
for i, sample_data in enumerate(train_loader): for i, sample_data in enumerate(train_loader):
input_nodes, output_nodes, blocks = sample_data input_nodes, output_nodes, blocks = sample_data
feats = embed_layer(input_nodes, feats = embed_layer(
blocks[0].srcdata['ntype'], input_nodes,
blocks[0].srcdata['type_id'], blocks[0].srcdata["ntype"],
node_feats) blocks[0].srcdata["type_id"],
node_feats,
)
logits = model(blocks, feats) logits = model(blocks, feats)
seed_idx = blocks[-1].dstdata['type_id'] seed_idx = blocks[-1].dstdata["type_id"]
loss = F.cross_entropy(logits, labels[seed_idx]) loss = F.cross_entropy(logits, labels[seed_idx])
optimizer.zero_grad() optimizer.zero_grad()
emb_optimizer.zero_grad() emb_optimizer.zero_grad()
...@@ -329,10 +387,14 @@ def track_acc(data): ...@@ -329,10 +387,14 @@ def track_acc(data):
optimizer.step() optimizer.step()
emb_optimizer.step() emb_optimizer.step()
print('start testing...') print("start testing...")
test_logits, test_seeds = evaluate(model, embed_layer, test_loader, node_feats) test_logits, test_seeds = evaluate(
model, embed_layer, test_loader, node_feats
)
test_loss = F.cross_entropy(test_logits, labels[test_seeds].cpu()).item() test_loss = F.cross_entropy(test_logits, labels[test_seeds].cpu()).item()
test_acc = th.sum(test_logits.argmax(dim=1) == labels[test_seeds].cpu()).item() / len(test_seeds) test_acc = th.sum(
test_logits.argmax(dim=1) == labels[test_seeds].cpu()
).item() / len(test_seeds)
return test_acc return test_acc
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment