Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
56ffb650
Unverified
Commit
56ffb650
authored
Jan 06, 2023
by
peizhou001
Committed by
GitHub
Jan 06, 2023
Browse files
[API Deprecation]Deprecate contrib module (#5114)
parent
436de3d1
Changes
81
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
23 deletions
+24
-23
tutorials/models/1_gnn/4_rgcn.py
tutorials/models/1_gnn/4_rgcn.py
+24
-23
No files found.
tutorials/models/1_gnn/4_rgcn.py
View file @
56ffb650
...
@@ -136,6 +136,7 @@ multiple edges among any given pair.
...
@@ -136,6 +136,7 @@ multiple edges among any given pair.
# efficient :class:`builtin R-GCN layer module <dgl.nn.pytorch.conv.RelGraphConv>`.
# efficient :class:`builtin R-GCN layer module <dgl.nn.pytorch.conv.RelGraphConv>`.
#
#
import
dgl
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -194,11 +195,11 @@ class RGCNLayer(nn.Module):
...
@@ -194,11 +195,11 @@ class RGCNLayer(nn.Module):
# for input layer, matrix multiply can be converted to be
# for input layer, matrix multiply can be converted to be
# an embedding lookup using source node id
# an embedding lookup using source node id
embed
=
weight
.
view
(
-
1
,
self
.
out_feat
)
embed
=
weight
.
view
(
-
1
,
self
.
out_feat
)
index
=
edges
.
data
[
'rel_type'
]
*
self
.
in_feat
+
edges
.
src
[
'id'
]
index
=
edges
.
data
[
dgl
.
ETYPE
]
*
self
.
in_feat
+
edges
.
src
[
'id'
]
return
{
'msg'
:
embed
[
index
]
*
edges
.
data
[
'norm'
]}
return
{
'msg'
:
embed
[
index
]
*
edges
.
data
[
'norm'
]}
else
:
else
:
def
message_func
(
edges
):
def
message_func
(
edges
):
w
=
weight
[
edges
.
data
[
'rel_type'
]]
w
=
weight
[
edges
.
data
[
dgl
.
ETYPE
]]
msg
=
torch
.
bmm
(
edges
.
src
[
'h'
].
unsqueeze
(
1
),
w
).
squeeze
()
msg
=
torch
.
bmm
(
edges
.
src
[
'h'
].
unsqueeze
(
1
),
w
).
squeeze
()
msg
=
msg
*
edges
.
data
[
'norm'
]
msg
=
msg
*
edges
.
data
[
'norm'
]
return
{
'msg'
:
msg
}
return
{
'msg'
:
msg
}
...
@@ -278,22 +279,20 @@ class Model(nn.Module):
...
@@ -278,22 +279,20 @@ class Model(nn.Module):
# This tutorial uses Institute for Applied Informatics and Formal Description Methods (AIFB) dataset from R-GCN paper.
# This tutorial uses Institute for Applied Informatics and Formal Description Methods (AIFB) dataset from R-GCN paper.
# load graph data
# load graph data
from
dgl.contrib.data
import
load_data
dataset
=
dgl
.
data
.
rdf
.
AIFBDataset
()
data
=
load_data
(
dataset
=
'aifb'
)
g
=
dataset
[
0
]
num_nodes
=
data
.
num_nodes
category
=
dataset
.
predict_category
num_rels
=
data
.
num_rels
train_mask
=
g
.
nodes
[
category
].
data
.
pop
(
'train_mask'
)
num_classes
=
data
.
num_classes
test_mask
=
g
.
nodes
[
category
].
data
.
pop
(
'test_mask'
)
labels
=
data
.
labels
train_idx
=
torch
.
nonzero
(
train_mask
,
as_tuple
=
False
).
squeeze
()
train_idx
=
data
.
train_idx
test_idx
=
torch
.
nonzero
(
test_mask
,
as_tuple
=
False
).
squeeze
()
# split training and validation set
labels
=
g
.
nodes
[
category
].
data
.
pop
(
'label'
)
val_idx
=
train_idx
[:
len
(
train_idx
)
//
5
]
num_rels
=
len
(
g
.
canonical_etypes
)
train_idx
=
train_idx
[
len
(
train_idx
)
//
5
:]
num_classes
=
dataset
.
num_classes
# normalization factor
# edge type and normalization factor
for
cetype
in
g
.
canonical_etypes
:
edge_type
=
torch
.
from_numpy
(
data
.
edge_type
)
g
.
edges
[
cetype
].
data
[
'norm'
]
=
dgl
.
norm_by_dst
(
g
,
cetype
).
unsqueeze
(
1
)
edge_norm
=
torch
.
from_numpy
(
data
.
edge_norm
).
unsqueeze
(
1
)
category_id
=
g
.
ntypes
.
index
(
category
)
labels
=
torch
.
from_numpy
(
labels
).
view
(
-
1
)
###############################################################################
###############################################################################
# Create graph and model
# Create graph and model
...
@@ -308,8 +307,9 @@ lr = 0.01 # learning rate
...
@@ -308,8 +307,9 @@ lr = 0.01 # learning rate
l2norm
=
0
# L2 norm coefficient
l2norm
=
0
# L2 norm coefficient
# create graph
# create graph
g
=
DGLGraph
((
data
.
edge_src
,
data
.
edge_dst
))
g
=
dgl
.
to_homogeneous
(
g
,
edata
=
[
'norm'
])
g
.
edata
.
update
({
'rel_type'
:
edge_type
,
'norm'
:
edge_norm
})
node_ids
=
torch
.
arange
(
g
.
num_nodes
())
target_idx
=
node_ids
[
g
.
ndata
[
dgl
.
NTYPE
]
==
category_id
]
# create model
# create model
model
=
Model
(
g
.
num_nodes
(),
model
=
Model
(
g
.
num_nodes
(),
...
@@ -331,6 +331,7 @@ model.train()
...
@@ -331,6 +331,7 @@ model.train()
for
epoch
in
range
(
n_epochs
):
for
epoch
in
range
(
n_epochs
):
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
logits
=
model
.
forward
(
g
)
logits
=
model
.
forward
(
g
)
logits
=
logits
[
target_idx
]
loss
=
F
.
cross_entropy
(
logits
[
train_idx
],
labels
[
train_idx
])
loss
=
F
.
cross_entropy
(
logits
[
train_idx
],
labels
[
train_idx
])
loss
.
backward
()
loss
.
backward
()
...
@@ -338,9 +339,9 @@ for epoch in range(n_epochs):
...
@@ -338,9 +339,9 @@ for epoch in range(n_epochs):
train_acc
=
torch
.
sum
(
logits
[
train_idx
].
argmax
(
dim
=
1
)
==
labels
[
train_idx
])
train_acc
=
torch
.
sum
(
logits
[
train_idx
].
argmax
(
dim
=
1
)
==
labels
[
train_idx
])
train_acc
=
train_acc
.
item
()
/
len
(
train_idx
)
train_acc
=
train_acc
.
item
()
/
len
(
train_idx
)
val_loss
=
F
.
cross_entropy
(
logits
[
val
_idx
],
labels
[
val
_idx
])
val_loss
=
F
.
cross_entropy
(
logits
[
test
_idx
],
labels
[
test
_idx
])
val_acc
=
torch
.
sum
(
logits
[
val
_idx
].
argmax
(
dim
=
1
)
==
labels
[
val
_idx
])
val_acc
=
torch
.
sum
(
logits
[
test
_idx
].
argmax
(
dim
=
1
)
==
labels
[
test
_idx
])
val_acc
=
val_acc
.
item
()
/
len
(
val
_idx
)
val_acc
=
val_acc
.
item
()
/
len
(
test
_idx
)
print
(
"Epoch {:05d} | "
.
format
(
epoch
)
+
print
(
"Epoch {:05d} | "
.
format
(
epoch
)
+
"Train Accuracy: {:.4f} | Train Loss: {:.4f} | "
.
format
(
"Train Accuracy: {:.4f} | Train Loss: {:.4f} | "
.
format
(
train_acc
,
loss
.
item
())
+
train_acc
,
loss
.
item
())
+
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment