Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9d417346
"tests/python/vscode:/vscode.git/clone" did not exist on "55af15d4a9736a530eb53faef4bca15d040090ca"
Unverified
Commit
9d417346
authored
Dec 17, 2023
by
Rhett Ying
Committed by
GitHub
Dec 17, 2023
Browse files
[GraphBolt] update to_dgl() in examples (#6763)
parent
a5e5f11a
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
5 additions
and
55 deletions
+5
-55
examples/multigpu/graphbolt/node_classification.py
examples/multigpu/graphbolt/node_classification.py
+0
-4
examples/sampling/graphbolt/lightning/node_classification.py
examples/sampling/graphbolt/lightning/node_classification.py
+0
-2
examples/sampling/graphbolt/link_prediction.py
examples/sampling/graphbolt/link_prediction.py
+2
-20
examples/sampling/graphbolt/node_classification.py
examples/sampling/graphbolt/node_classification.py
+0
-5
examples/sampling/graphbolt/quickstart/link_prediction.py
examples/sampling/graphbolt/quickstart/link_prediction.py
+0
-14
examples/sampling/graphbolt/rgcn/hetero_rgcn.py
examples/sampling/graphbolt/rgcn/hetero_rgcn.py
+3
-9
tutorials/multi/2_node_classification.py
tutorials/multi/2_node_classification.py
+0
-1
No files found.
examples/multigpu/graphbolt/node_classification.py
View file @
9d417346
...
...
@@ -153,7 +153,6 @@ def evaluate(rank, model, dataloader, num_classes, device):
for
step
,
data
in
(
tqdm
.
tqdm
(
enumerate
(
dataloader
))
if
rank
==
0
else
enumerate
(
dataloader
)
):
data
=
data
.
to_dgl
()
blocks
=
data
.
blocks
x
=
data
.
node_features
[
"feat"
]
y
.
append
(
data
.
labels
)
...
...
@@ -206,9 +205,6 @@ def train(
if
rank
==
0
else
enumerate
(
train_dataloader
)
):
# Convert data to DGL format.
data
=
data
.
to_dgl
()
# The input features are from the source nodes in the first
# layer's computation graph.
x
=
data
.
node_features
[
"feat"
]
...
...
examples/sampling/graphbolt/lightning/node_classification.py
View file @
9d417346
...
...
@@ -93,7 +93,6 @@ class SAGE(LightningModule):
)
def
training_step
(
self
,
batch
,
batch_idx
):
batch
=
batch
.
to_dgl
()
blocks
=
[
block
.
to
(
"cuda"
)
for
block
in
batch
.
blocks
]
x
=
batch
.
node_features
[
"feat"
]
y
=
batch
.
labels
.
to
(
"cuda"
)
...
...
@@ -111,7 +110,6 @@ class SAGE(LightningModule):
return
loss
def
validation_step
(
self
,
batch
,
batch_idx
):
batch
=
batch
.
to_dgl
()
blocks
=
[
block
.
to
(
"cuda"
)
for
block
in
batch
.
blocks
]
x
=
batch
.
node_features
[
"feat"
]
y
=
batch
.
labels
.
to
(
"cuda"
)
...
...
examples/sampling/graphbolt/link_prediction.py
View file @
9d417346
...
...
@@ -101,7 +101,6 @@ class SAGE(nn.Module):
)
feature
=
feature
.
to
(
device
)
for
step
,
data
in
tqdm
.
tqdm
(
enumerate
(
dataloader
)):
data
=
data
.
to_dgl
()
x
=
feature
[
data
.
input_nodes
]
hidden_x
=
layer
(
data
.
blocks
[
0
],
x
)
# len(blocks) = 1
if
not
is_last_layer
:
...
...
@@ -237,20 +236,6 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
return
dataloader
def
to_binary_link_dgl_computing_pack
(
data
:
gb
.
DGLMiniBatch
):
"""Convert the minibatch to a training pair and a label tensor."""
pos_src
,
pos_dst
=
data
.
positive_node_pairs
neg_src
,
neg_dst
=
data
.
negative_node_pairs
node_pairs
=
(
torch
.
cat
((
pos_src
,
neg_src
),
dim
=
0
),
torch
.
cat
((
pos_dst
,
neg_dst
),
dim
=
0
),
)
pos_label
=
torch
.
ones_like
(
pos_src
)
neg_label
=
torch
.
zeros_like
(
neg_src
)
labels
=
torch
.
cat
([
pos_label
,
neg_label
],
dim
=
0
)
return
(
node_pairs
,
labels
.
float
())
@
torch
.
no_grad
()
def
compute_mrr
(
args
,
model
,
evaluator
,
node_emb
,
src
,
dst
,
neg_dst
):
"""Compute the Mean Reciprocal Rank (MRR) for given source and destination
...
...
@@ -324,11 +309,8 @@ def train(args, model, graph, features, train_set):
total_loss
=
0
start_epoch_time
=
time
.
time
()
for
step
,
data
in
enumerate
(
dataloader
):
# Convert data to DGL format.
data
=
data
.
to_dgl
()
# Unpack MiniBatch.
compacted_pairs
,
labels
=
to_binary_link_dgl_computing_pack
(
data
)
# Get node pairs with labels for loss calculation.
compacted_pairs
,
labels
=
data
.
node_pairs_with_labels
node_feature
=
data
.
node_features
[
"feat"
]
# Convert sampled subgraphs to DGL blocks.
blocks
=
data
.
blocks
...
...
examples/sampling/graphbolt/node_classification.py
View file @
9d417346
...
...
@@ -202,7 +202,6 @@ class SAGE(nn.Module):
feature
=
feature
.
to
(
device
)
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
data
=
data
.
to_dgl
()
x
=
feature
[
data
.
input_nodes
]
hidden_x
=
layer
(
data
.
blocks
[
0
],
x
)
# len(blocks) = 1
if
not
is_last_layer
:
...
...
@@ -261,7 +260,6 @@ def evaluate(args, model, graph, features, itemset, num_classes):
)
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
data
=
data
.
to_dgl
()
x
=
data
.
node_features
[
"feat"
]
y
.
append
(
data
.
labels
)
y_hats
.
append
(
model
(
data
.
blocks
,
x
))
...
...
@@ -292,9 +290,6 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
model
.
train
()
total_loss
=
0
for
step
,
data
in
enumerate
(
dataloader
):
# Convert data to DGL format.
data
=
data
.
to_dgl
()
# The input features from the source nodes in the first layer's
# computation graph.
x
=
data
.
node_features
[
"feat"
]
...
...
examples/sampling/graphbolt/quickstart/link_prediction.py
View file @
9d417346
...
...
@@ -76,20 +76,6 @@ class GraphSAGE(nn.Module):
return
hidden_x
def
to_binary_link_dgl_computing_pack
(
data
:
gb
.
MiniBatch
):
"""Convert the minibatch to a training pair and a label tensor."""
pos_src
,
pos_dst
=
data
.
positive_node_pairs
neg_src
,
neg_dst
=
data
.
negative_node_pairs
node_pairs
=
(
torch
.
cat
((
pos_src
,
neg_src
),
dim
=
0
),
torch
.
cat
((
pos_dst
,
neg_dst
),
dim
=
0
),
)
pos_label
=
torch
.
ones_like
(
pos_src
)
neg_label
=
torch
.
zeros_like
(
neg_src
)
labels
=
torch
.
cat
([
pos_label
,
neg_label
],
dim
=
0
)
return
(
node_pairs
,
labels
)
@
torch
.
no_grad
()
def
evaluate
(
model
,
dataset
,
device
):
model
.
eval
()
...
...
examples/sampling/graphbolt/rgcn/hetero_rgcn.py
View file @
9d417346
...
...
@@ -176,7 +176,7 @@ def rel_graph_embed(graph, embed_size):
for the "paper" node type.
"""
node_num
=
{}
node_type_to_id
=
graph
.
metadata
.
node_type_to_id
node_type_to_id
=
graph
.
node_type_to_id
node_type_offset
=
graph
.
node_type_offset
for
ntype
,
ntype_id
in
node_type_to_id
.
items
():
# Skip the "paper" node type.
...
...
@@ -328,12 +328,12 @@ class EntityClassify(nn.Module):
# Generate and sort a list of unique edge types from the input graph.
# eg. ['writes', 'cites']
etypes
=
list
(
graph
.
metadata
.
edge_type_to_id
.
keys
())
etypes
=
list
(
graph
.
edge_type_to_id
.
keys
())
etypes
=
[
gb
.
etype_str_to_tuple
(
etype
)[
1
]
for
etype
in
etypes
]
self
.
relation_names
=
etypes
self
.
relation_names
.
sort
()
self
.
dropout
=
0.5
ntypes
=
list
(
graph
.
metadata
.
node_type_to_id
.
keys
())
ntypes
=
list
(
graph
.
node_type_to_id
.
keys
())
self
.
layers
=
nn
.
ModuleList
()
# First layer: transform input features to hidden features. Use ReLU
...
...
@@ -487,9 +487,6 @@ def evaluate(
y_true
=
list
()
for
data
in
tqdm
(
data_loader
,
desc
=
"Inference"
):
# Convert data to DGL format for computing.
data
=
data
.
to_dgl
()
blocks
=
[
block
.
to
(
device
)
for
block
in
data
.
blocks
]
node_features
=
extract_node_features
(
name
,
blocks
[
0
],
data
,
node_embed
,
device
...
...
@@ -558,9 +555,6 @@ def run(
total_loss
=
0
for
data
in
tqdm
(
data_loader
,
desc
=
f
"Training~Epoch
{
epoch
:
02
d
}
"
):
# Convert data to DGL format for computing.
data
=
data
.
to_dgl
()
# Convert MiniBatch to DGL Blocks.
blocks
=
[
block
.
to
(
device
)
for
block
in
data
.
blocks
]
...
...
tutorials/multi/2_node_classification.py
View file @
9d417346
...
...
@@ -118,7 +118,6 @@ def create_dataloader(
)
datapipe
=
datapipe
.
sample_neighbor
(
graph
,
[
10
,
10
,
10
])
datapipe
=
datapipe
.
fetch_feature
(
features
,
node_feature_keys
=
[
"feat"
])
datapipe
=
datapipe
.
to_dgl
()
datapipe
=
datapipe
.
copy_to
(
device
)
dataloader
=
gb
.
DataLoader
(
datapipe
,
num_workers
=
0
)
return
dataloader
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment