OpenDAS / dgl · Commit 0348ad3d

[GraphBolt] Move to_dgl out of data loader in examples. (#6705)

Authored Dec 07, 2023 by Ramon Zhou; committed by GitHub on Dec 07, 2023.
Parent: 7e12f973
Showing 7 changed files with 33 additions and 37 deletions (+33 -37):
  examples/multigpu/graphbolt/node_classification.py             +4  -1
  examples/sampling/graphbolt/lightning/node_classification.py   +3  -2
  examples/sampling/graphbolt/link_prediction.py                 +4  -12
  examples/sampling/graphbolt/node_classification.py             +6  -13
  examples/sampling/graphbolt/quickstart/link_prediction.py      +6  -3
  examples/sampling/graphbolt/quickstart/node_classification.py  +4  -3
  examples/sampling/graphbolt/rgcn/hetero_rgcn.py                 +6  -3
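Taken together, the change is mechanical: the `.to_dgl()` stage is dropped from the datapipe that feeds `gb.DataLoader`, and each consumer instead calls `data.to_dgl()` on the mini-batch it receives. A minimal sketch of the resulting pattern (the `ItemSampler` stage, the batch size, and the function names are assumptions; the remaining calls mirror the hunks below):

    import dgl.graphbolt as gb

    def create_dataloader(itemset, graph, features, device, fanouts):
        # Build the GraphBolt pipeline; note there is no .to_dgl() stage anymore.
        datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)  # assumed stage
        datapipe = datapipe.sample_neighbor(graph, fanouts)
        datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
        datapipe = datapipe.copy_to(device)
        return gb.DataLoader(datapipe)

    def train_one_epoch(model, dataloader):
        # Each consumer now converts the mini-batch to DGL format itself.
        for step, data in enumerate(dataloader):
            data = data.to_dgl()            # MiniBatch -> DGL mini-batch
            x = data.node_features["feat"]  # features of the sampled nodes
            y = data.labels                 # labels of the seed nodes
            y_hat = model(data.blocks, x)   # message passing over DGL blocks
            ...                             # loss / backward as in the examples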
examples/multigpu/graphbolt/node_classification.py

@@ -128,7 +128,6 @@ def create_dataloader(
     )
     datapipe = datapipe.sample_neighbor(graph, args.fanout)
     datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
-    datapipe = datapipe.to_dgl()
     ############################################################################
     # [Note]:

@@ -154,6 +153,7 @@ def evaluate(rank, model, dataloader, num_classes, device):
     for step, data in (
         tqdm.tqdm(enumerate(dataloader)) if rank == 0 else enumerate(dataloader)
     ):
+        data = data.to_dgl()
         blocks = data.blocks
         x = data.node_features["feat"]
         y.append(data.labels)

@@ -206,6 +206,9 @@ def train(
             if rank == 0
             else enumerate(train_dataloader)
         ):
+            # Convert data to DGL format.
+            data = data.to_dgl()
             # The input features are from the source nodes in the first
             # layer's computation graph.
             x = data.node_features["feat"]
examples/sampling/graphbolt/lightning/node_classification.py

@@ -93,6 +93,7 @@ class SAGE(LightningModule):
         )
     def training_step(self, batch, batch_idx):
+        batch = batch.to_dgl()
         blocks = [block.to("cuda") for block in batch.blocks]
         x = batch.node_features["feat"]
         y = batch.labels.to("cuda")

@@ -110,6 +111,7 @@ class SAGE(LightningModule):
         return loss
     def validation_step(self, batch, batch_idx):
+        batch = batch.to_dgl()
         blocks = [block.to("cuda") for block in batch.blocks]
         x = batch.node_features["feat"]
         y = batch.labels.to("cuda")

@@ -158,7 +160,6 @@ class DataModule(LightningDataModule):
         )
         datapipe = sampler(self.graph, self.fanouts)
         datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
-        datapipe = datapipe.to_dgl()
         dataloader = gb.DataLoader(datapipe, num_workers=self.num_workers)
         return dataloader

@@ -214,7 +215,7 @@ if __name__ == "__main__":
         args.num_workers,
     )
     in_size = dataset.feature.size("node", None, "feat")[0]
-    model = SAGE(in_size, 256, datamodule.num_classes).to(torch.double)
+    model = SAGE(in_size, 256, datamodule.num_classes)
     # Train.
     checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
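The Lightning example applies the same change inside its step methods: the GraphBolt mini-batch is converted at the top of training_step/validation_step rather than in the datapipe. A minimal sketch of a training_step following the hunk above (the forward call and the cross-entropy loss are assumptions, not part of the diff):

    import torch.nn.functional as F
    from pytorch_lightning import LightningModule

    class SAGE(LightningModule):
        def training_step(self, batch, batch_idx):
            batch = batch.to_dgl()                                  # MiniBatch -> DGL mini-batch
            blocks = [block.to("cuda") for block in batch.blocks]   # move blocks to GPU
            x = batch.node_features["feat"]                         # input node features
            y = batch.labels.to("cuda")                             # seed-node labels
            y_hat = self(blocks, x)                                 # forward pass (assumed)
            return F.cross_entropy(y_hat, y)                        # loss choice is assumed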
examples/sampling/graphbolt/link_prediction.py

@@ -100,6 +100,7 @@ class SAGE(nn.Module):
             )
             feature = feature.to(device)
             for step, data in tqdm.tqdm(enumerate(dataloader)):
+                data = data.to_dgl()
                 x = feature[data.input_nodes]
                 hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
                 if not is_last_layer:

@@ -207,18 +208,6 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
     if is_train:
     datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
-    ############################################################################
-    # [Step-4]:
-    # datapipe.to_dgl()
-    # [Input]:
-    # 'datapipe': The previous datapipe object.
-    # [Output]:
-    # A DGLMiniBatch used for computing.
-    # [Role]:
-    # Convert a mini-batch to dgl-minibatch.
-    ############################################################################
-    datapipe = datapipe.to_dgl()
     ############################################################################
     # [Input]:
     # 'device': The device to copy the data to.

@@ -332,6 +321,9 @@ def train(args, model, graph, features, train_set):
     total_loss = 0
     start_epoch_time = time.time()
     for step, data in enumerate(dataloader):
+        # Convert data to DGL format.
+        data = data.to_dgl()
         # Unpack MiniBatch.
         compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
         node_feature = data.node_features["feat"]
examples/sampling/graphbolt/node_classification.py

@@ -126,18 +126,6 @@ def create_dataloader(
     ############################################################################
     # [Step-4]:
-    # self.to_dgl()
-    # [Input]:
-    # 'datapipe': The previous datapipe object.
-    # [Output]:
-    # A DGLMiniBatch used for computing.
-    # [Role]:
-    # Convert a mini-batch to dgl-minibatch.
-    ############################################################################
-    datapipe = datapipe.to_dgl()
-    ############################################################################
-    # [Step-5]:
     # self.copy_to()
     # [Input]:
     # 'device': The device to copy the data to.

@@ -147,7 +135,7 @@ def create_dataloader(
     datapipe = datapipe.copy_to(device=device)
     ############################################################################
-    # [Step-6]:
+    # [Step-5]:
     # gb.DataLoader()
     # [Input]:
     # 'datapipe': The datapipe object to be used for data loading.

@@ -214,6 +202,7 @@ class SAGE(nn.Module):
             feature = feature.to(device)
             for step, data in tqdm(enumerate(dataloader)):
+                data = data.to_dgl()
                 x = feature[data.input_nodes]
                 hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
                 if not is_last_layer:

@@ -272,6 +261,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
     )
     for step, data in tqdm(enumerate(dataloader)):
+        data = data.to_dgl()
         x = data.node_features["feat"]
         y.append(data.labels)
         y_hats.append(model(data.blocks, x))

@@ -302,6 +292,9 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
         model.train()
         total_loss = 0
         for step, data in enumerate(dataloader):
+            # Convert data to DGL format.
+            data = data.to_dgl()
             # The input features from the source nodes in the first layer's
             # computation graph.
             x = data.node_features["feat"]
examples/sampling/graphbolt/quickstart/link_prediction.py

@@ -47,9 +47,6 @@ def create_dataloader(dateset, device, is_train=True):
         dataset.feature, node_feature_keys=["feat"]
     )
-    # Convert the mini-batch to DGL format to train a DGL model.
-    datapipe = datapipe.to_dgl()
     # Copy the mini-batch to the designated device for training.
     datapipe = datapipe.copy_to(device)

@@ -101,6 +98,9 @@ def evaluate(model, dataset, device):
     logits = []
     labels = []
     for step, data in enumerate(dataloader):
+        # Convert data to DGL format for computing.
+        data = data.to_dgl()
         # Unpack MiniBatch.
         compacted_pairs, label = to_binary_link_dgl_computing_pack(data)

@@ -140,6 +140,9 @@ def train(model, dataset, device):
         # mini-batches.
         ########################################################################
         for step, data in enumerate(dataloader):
+            # Convert data to DGL format for computing.
+            data = data.to_dgl()
             # Unpack MiniBatch.
             compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
examples/sampling/graphbolt/quickstart/node_classification.py

@@ -25,9 +25,6 @@ def create_dataloader(dateset, itemset, device):
         dataset.feature, node_feature_keys=["feat"]
     )
-    # Convert the mini-batch to DGL format to train a DGL model.
-    datapipe = datapipe.to_dgl()
     # Copy the mini-batch to the designated device for training.
     datapipe = datapipe.copy_to(device)

@@ -60,6 +57,7 @@ def evaluate(model, dataset, itemset, device):
     dataloader = create_dataloader(dataset, itemset, device)
     for step, data in enumerate(dataloader):
+        data = data.to_dgl()
         x = data.node_features["feat"]
         y.append(data.labels)
         y_hats.append(model(data.blocks, x))

@@ -86,6 +84,9 @@ def train(model, dataset, device):
         # mini-batches.
         ########################################################################
         for step, data in enumerate(dataloader):
+            # Convert data to DGL format for computing.
+            data = data.to_dgl()
             # The features of sampled nodes.
             x = data.node_features["feat"]
examples/sampling/graphbolt/rgcn/hetero_rgcn.py

@@ -124,9 +124,6 @@ def create_dataloader(
         node_feature_keys["institution"] = ["feat"]
     datapipe = datapipe.fetch_feature(features, node_feature_keys)
-    # Convert a mini-batch to dgl mini-batch for computing.
-    datapipe = datapipe.to_dgl()
     # Move the mini-batch to the appropriate device.
     # `device`:
     #   The device to move the mini-batch to.

@@ -490,6 +487,9 @@ def evaluate(
     y_true = list()
     for data in tqdm(data_loader, desc="Inference"):
+        # Convert data to DGL format for computing.
+        data = data.to_dgl()
         blocks = [block.to(device) for block in data.blocks]
         node_features = extract_node_features(
             name, blocks[0], data, node_embed, device

@@ -558,6 +558,9 @@ def run(
         total_loss = 0
         for data in tqdm(data_loader, desc=f"Training~Epoch {epoch:02d}"):
+            # Convert data to DGL format for computing.
+            data = data.to_dgl()
             # Convert MiniBatch to DGL Blocks.
             blocks = [block.to(device) for block in data.blocks]
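In the heterogeneous RGCN example, fetch_feature receives a per-node-type dictionary of feature keys rather than a flat list; the per-batch conversion is otherwise identical. A minimal sketch of that shape (only the "institution" entry appears in the hunk above; the "paper" node type, helper names, and surrounding code are assumptions):

    def attach_features(datapipe, features):
        # Map each featured node type to the keys to fetch ("paper" is assumed).
        node_feature_keys = {"paper": ["feat"]}
        node_feature_keys["institution"] = ["feat"]
        return datapipe.fetch_feature(features, node_feature_keys)

    def inference(model, data_loader, device):
        for data in data_loader:
            data = data.to_dgl()                                   # convert per batch
            blocks = [block.to(device) for block in data.blocks]   # move blocks to device
            # ... run `model` on `blocks` as in the example above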