Unverified Commit 56ffb650 authored by peizhou001, committed by GitHub

[API Deprecation] Deprecate contrib module (#5114)

parent 436de3d1
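At a glance, the commit ports the R-GCN entity-classification tutorial from the removed `dgl.contrib.data` loader to the maintained `dgl.data.rdf` datasets. A minimal sketch of the new loading path, assembled from the new side of the hunks below (assumes a DGL release that ships `dgl.data.rdf`):

```python
# Sketch of the replacement for the deprecated contrib loader; the old call
# is kept as a comment for contrast. Names mirror the new side of this diff.
import dgl
import torch

# Deprecated: from dgl.contrib.data import load_data; data = load_data(dataset='aifb')
dataset = dgl.data.rdf.AIFBDataset()
g = dataset[0]                          # heterogeneous RDF graph
category = dataset.predict_category     # node type that carries the labels
labels = g.nodes[category].data['label']
print(dataset.num_classes, len(g.canonical_etypes))
```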
@@ -136,6 +136,7 @@ multiple edges among any given pair.
 # efficient :class:`builtin R-GCN layer module <dgl.nn.pytorch.conv.RelGraphConv>`.
 #
+import dgl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -194,11 +195,11 @@ class RGCNLayer(nn.Module):
                 # for input layer, matrix multiply can be converted to be
                 # an embedding lookup using source node id
                 embed = weight.view(-1, self.out_feat)
-                index = edges.data['rel_type'] * self.in_feat + edges.src['id']
+                index = edges.data[dgl.ETYPE] * self.in_feat + edges.src['id']
                 return {'msg': embed[index] * edges.data['norm']}
         else:
             def message_func(edges):
-                w = weight[edges.data['rel_type']]
+                w = weight[edges.data[dgl.ETYPE]]
                 msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze()
                 msg = msg * edges.data['norm']
                 return {'msg': msg}
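The `'rel_type'` field the tutorial used to populate by hand is replaced by `dgl.ETYPE`, the reserved feature key (the string `'_TYPE'`) that `dgl.to_homogeneous` fills with each edge's original relation ID. A small sketch of that contract (toy graph, illustrative only):

```python
# Toy check: to_homogeneous records original types under the reserved keys
# dgl.NTYPE (per node) and dgl.ETYPE (per edge), so message functions can
# read relation IDs without a hand-maintained 'rel_type' field.
import dgl
import torch

hg = dgl.heterograph({
    ('user', 'follows', 'user'): (torch.tensor([0]), torch.tensor([1])),
    ('user', 'likes', 'item'): (torch.tensor([0]), torch.tensor([0])),
})
g = dgl.to_homogeneous(hg)
print(g.edata[dgl.ETYPE])   # one relation ID per edge
print(g.ndata[dgl.NTYPE])   # one node-type ID per node
```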
@@ -278,22 +279,20 @@ class Model(nn.Module):
 # This tutorial uses Institute for Applied Informatics and Formal Description Methods (AIFB) dataset from R-GCN paper.
 # load graph data
-from dgl.contrib.data import load_data
-data = load_data(dataset='aifb')
-num_nodes = data.num_nodes
-num_rels = data.num_rels
-num_classes = data.num_classes
-labels = data.labels
-train_idx = data.train_idx
-# split training and validation set
-val_idx = train_idx[:len(train_idx) // 5]
-train_idx = train_idx[len(train_idx) // 5:]
-
-# edge type and normalization factor
-edge_type = torch.from_numpy(data.edge_type)
-edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1)
-labels = torch.from_numpy(labels).view(-1)
+dataset = dgl.data.rdf.AIFBDataset()
+g = dataset[0]
+category = dataset.predict_category
+train_mask = g.nodes[category].data.pop('train_mask')
+test_mask = g.nodes[category].data.pop('test_mask')
+train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
+test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
+labels = g.nodes[category].data.pop('label')
+num_rels = len(g.canonical_etypes)
+num_classes = dataset.num_classes
+
+# normalization factor
+for cetype in g.canonical_etypes:
+    g.edges[cetype].data['norm'] = dgl.norm_by_dst(g, cetype).unsqueeze(1)
+category_id = g.ntypes.index(category)
 ###############################################################################
 # Create graph and model
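`dgl.norm_by_dst(g, cetype)` supplies the per-edge normalization constant that the contrib loader used to precompute as `data.edge_norm`: for each edge it returns one over the in-degree of the destination node under that relation. A hand-rolled equivalent for a single toy relation (the `clamp` is only a division-by-zero guard; every destination of an existing edge has in-degree at least one):

```python
# Sketch: reproduce dgl.norm_by_dst by hand for one relation to show what
# the 'norm' edge feature holds (1 / in-degree of the destination node).
import dgl
import torch

g = dgl.heterograph({('a', 'r', 'b'): (torch.tensor([0, 1, 2]),
                                        torch.tensor([0, 0, 1]))})
deg = g.in_degrees(etype='r').float().clamp(min=1)   # in-degree per 'b' node
_, dst = g.edges(etype='r')
manual = 1.0 / deg[dst]                              # per-edge constant
print(torch.allclose(manual, dgl.norm_by_dst(g, ('a', 'r', 'b'))))  # True
```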
@@ -308,8 +307,9 @@ lr = 0.01 # learning rate
 l2norm = 0 # L2 norm coefficient
 
 # create graph
-g = DGLGraph((data.edge_src, data.edge_dst))
-g.edata.update({'rel_type': edge_type, 'norm': edge_norm})
+g = dgl.to_homogeneous(g, edata=['norm'])
+node_ids = torch.arange(g.num_nodes())
+target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
 
 # create model
 model = Model(g.num_nodes(),
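The homograph conversion keeps only the requested edge feature (`edata=['norm']`) and groups nodes by their original type, so a boolean mask on `g.ndata[dgl.NTYPE]` recovers the homogeneous IDs of the labelled entity type. A toy sketch of that relabelling:

```python
# Sketch: after to_homogeneous, nodes of each original type occupy a
# contiguous ID range, so masking on dgl.NTYPE yields the target rows.
import dgl
import torch

hg = dgl.heterograph({
    ('person', 'knows', 'topic'): (torch.tensor([0, 1]), torch.tensor([0, 0])),
})
g = dgl.to_homogeneous(hg)
category_id = hg.ntypes.index('person')
target_idx = torch.arange(g.num_nodes())[g.ndata[dgl.NTYPE] == category_id]
print(target_idx)   # homogeneous IDs of the 'person' nodes -> tensor([0, 1])
```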
@@ -331,6 +331,7 @@ model.train()
 for epoch in range(n_epochs):
     optimizer.zero_grad()
     logits = model.forward(g)
+    logits = logits[target_idx]
     loss = F.cross_entropy(logits[train_idx], labels[train_idx])
     loss.backward()
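The new `logits = logits[target_idx]` line narrows the per-node scores to the labelled entity type, since `train_idx`/`test_idx` index within that type rather than the full homogeneous node space. A toy shape check (all sizes made up for illustration):

```python
# Toy shape check for the indexing contract assumed by the training loop.
import torch

logits = torch.randn(10, 3)              # stand-in for model.forward(g) output
target_idx = torch.tensor([2, 3, 7, 9])  # homogeneous IDs of labelled entities
logits = logits[target_idx]              # (num_labelled, num_classes)
train_idx = torch.tensor([0, 1, 2])      # indices *within* the labelled set
print(logits[train_idx].shape)           # torch.Size([3, 3])
```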
@@ -338,9 +339,9 @@ for epoch in range(n_epochs):
     train_acc = torch.sum(logits[train_idx].argmax(dim=1) == labels[train_idx])
     train_acc = train_acc.item() / len(train_idx)
-    val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])
-    val_acc = torch.sum(logits[val_idx].argmax(dim=1) == labels[val_idx])
-    val_acc = val_acc.item() / len(val_idx)
+    val_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
+    val_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx])
+    val_acc = val_acc.item() / len(test_idx)
     print("Epoch {:05d} | ".format(epoch) +
           "Train Accuracy: {:.4f} | Train Loss: {:.4f} | ".format(
               train_acc, loss.item()) +
...