OpenDAS / dgl · Commits

Commit 4c6e6543 (Unverified), authored Nov 08, 2023 by Mingbang Wang, committed by GitHub on Nov 08, 2023

[Misc] Modify `create_dataloader()` (#6537)
Parent: 382a2de7

Showing 1 changed file with 36 additions and 22 deletions:

examples/sampling/graphbolt/node_classification.py (+36, -22), view file @ 4c6e6543
```diff
@@ -49,7 +49,9 @@ import torchmetrics.functional as MF
 from tqdm import tqdm
 
 
-def create_dataloader(args, graph, features, itemset, job):
+def create_dataloader(
+    graph, features, itemset, batch_size, fanout, device, num_workers, job
+):
     """
     [HIGHLIGHT]
     Get a GraphBolt version of a dataloader for node classification tasks.
```
```diff
@@ -59,17 +61,10 @@ def create_dataloader(args, graph, features, itemset, job):
     Parameters
     ----------
-    args : Namespace
-        The arguments parsed by `parser.parse_args()`.
-    graph : SamplingGraph
-        The network topology for sampling.
-    features : FeatureStore
-        The node features.
-    itemset : Union[ItemSet, ItemSetDict]
-        Data to be sampled.
     job : one of ["train", "evaluate", "infer"]
         The stage where dataloader is created, with options "train", "evaluate"
         and "infer".
+    Other parameters are explicated in the comments below.
     """
     ############################################################################
```
```diff
@@ -77,7 +72,7 @@ def create_dataloader(args, graph, features, itemset, job):
     # gb.ItemSampler()
     # [Input]:
     # 'itemset': The current dataset. (e.g. `train_set` or `valid_set`)
-    # 'args.batch_size': Specify the number of samples to be processed together,
+    # 'batch_size': Specify the number of samples to be processed together,
     # referred to as a 'mini-batch'. (The term 'mini-batch' is used here to
     # indicate a subset of the entire dataset that is processed together. This
     # is in contrast to processing the entire dataset, known as a 'full batch'.)
```
```diff
@@ -91,7 +86,7 @@ def create_dataloader(args, graph, features, itemset, job):
     # Initialize the ItemSampler to sample mini-batche from the dataset.
     ############################################################################
     datapipe = gb.ItemSampler(
-        itemset, batch_size=args.batch_size, shuffle=(job == "train")
+        itemset, batch_size=batch_size, shuffle=(job == "train")
     )
     ############################################################################
```
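For readers new to GraphBolt, Step-1 can be exercised on its own. A minimal sketch follows, assuming the `gb.ItemSet(items, names=...)` constructor of this era; the toy ID range, the batch size of 16, and the printed MiniBatch are illustrative, not part of the commit.

```python
import torch
import dgl.graphbolt as gb

# Toy ItemSet of 100 seed node IDs (illustrative; the real example uses the
# dataset's train/validation/test sets).
itemset = gb.ItemSet(torch.arange(100), names="seed_nodes")

# Step-1 as refactored: `batch_size` is now a plain parameter rather than
# `args.batch_size`; shuffling is enabled only for the "train" job.
job = "train"
datapipe = gb.ItemSampler(itemset, batch_size=16, shuffle=(job == "train"))

for minibatch in datapipe:
    print(minibatch)  # a mini-batch carrying 16 seed node IDs (era API assumed)
    break
```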
```diff
@@ -99,8 +94,8 @@ def create_dataloader(args, graph, features, itemset, job):
     # self.sample_neighbor()
     # [Input]:
     # 'graph': The network topology for sampling.
-    # '[-1] or args.fanout': Number of neighbors to sample per node. In
-    # training or validation, the length of args.fanout should be equal to the
+    # '[-1] or fanout': Number of neighbors to sample per node. In
+    # training or validation, the length of `fanout` should be equal to the
     # number of layers in the model. In inference, this parameter is set to
     # [-1], indicating that all neighbors of a node are sampled.
     # [Output]:
```
```diff
@@ -109,7 +104,7 @@ def create_dataloader(args, graph, features, itemset, job):
     # Initialize a neighbor sampler for sampling the neighborhoods of nodes.
     ############################################################################
     datapipe = datapipe.sample_neighbor(
-        graph, args.fanout if job != "infer" else [-1]
+        graph, fanout if job != "infer" else [-1]
     )
     ############################################################################
```
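The `fanout if job != "infer" else [-1]` switch is the behavioral heart of this hunk. The rule from the comment block, restated as a self-contained snippet (the concrete fanout values are hypothetical, not from the commit):

```python
# One fanout entry per GNN layer when training or validating; [-1] means
# "all neighbors" and is used for inference.
num_layers = 2
job = "train"

fanout = [-1] if job == "infer" else [10] * num_layers

# Mirrors the invariant stated in the comment: outside of inference,
# len(fanout) must match the number of layers in the model.
assert job == "infer" or len(fanout) == num_layers
print(fanout)  # [10, 10]
```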
```diff
@@ -148,22 +143,20 @@ def create_dataloader(args, graph, features, itemset, job):
     # [Output]:
     # A CopyTo object to copy the data to the specified device.
     ############################################################################
-    datapipe = datapipe.copy_to(device=args.device)
+    datapipe = datapipe.copy_to(device=device)
     ############################################################################
     # [Step-6]:
     # gb.MultiProcessDataLoader()
     # [Input]:
     # 'datapipe': The datapipe object to be used for data loading.
-    # 'args.num_workers': The number of processes to be used for data loading.
+    # 'num_workers': The number of processes to be used for data loading.
     # [Output]:
     # A MultiProcessDataLoader object to handle data loading.
     # [Role]:
     # Initialize a multi-process dataloader to load the data in parallel.
     ############################################################################
-    dataloader = gb.MultiProcessDataLoader(
-        datapipe, num_workers=args.num_workers
-    )
+    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=num_workers)
 
     # Return the fully-initialized DataLoader object.
     return dataloader
```
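Taken together, the hunks above compose the following datapipe chain. This is a condensed sketch, not the file itself: the helper name `build_datapipe_sketch` is ours, and only the stages visible in this diff are included (the real example has further steps, e.g. feature fetching, elided here).

```python
import dgl.graphbolt as gb

def build_datapipe_sketch(
    graph, itemset, batch_size, fanout, device, num_workers, job
):
    # Step-1: batch the seed items; shuffle only while training.
    datapipe = gb.ItemSampler(
        itemset, batch_size=batch_size, shuffle=(job == "train")
    )
    # Step-2: sample neighbors; all of them ([-1]) when inferring.
    datapipe = datapipe.sample_neighbor(
        graph, fanout if job != "infer" else [-1]
    )
    # Step-5: move each mini-batch to the target device.
    datapipe = datapipe.copy_to(device=device)
    # Step-6: wrap the chain in a multi-process loader.
    return gb.MultiProcessDataLoader(datapipe, num_workers=num_workers)
```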
```diff
@@ -240,7 +233,14 @@ def layerwise_infer(
 ):
     model.eval()
     dataloader = create_dataloader(
-        args, graph, features, all_nodes_set, job="infer"
+        graph=graph,
+        features=features,
+        itemset=all_nodes_set,
+        batch_size=4 * args.batch_size,
+        fanout=[-1],
+        device=args.device,
+        num_workers=args.num_workers,
+        job="infer",
     )
     pred = model.inference(graph, features, dataloader, args.device)
     pred = pred[test_set._items[0]]
```
```diff
@@ -260,7 +260,14 @@ def evaluate(args, model, graph, features, itemset, num_classes):
     y = []
     y_hats = []
     dataloader = create_dataloader(
-        args, graph, features, itemset, job="evaluate"
+        graph=graph,
+        features=features,
+        itemset=itemset,
+        batch_size=args.batch_size,
+        fanout=args.fanout,
+        device=args.device,
+        num_workers=args.num_workers,
+        job="evaluate",
     )
 
     for step, data in tqdm(enumerate(dataloader)):
```
```diff
@@ -279,7 +286,14 @@ def evaluate(args, model, graph, features, itemset, num_classes):
 def train(args, graph, features, train_set, valid_set, num_classes, model):
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
     dataloader = create_dataloader(
-        args, graph, features, train_set, job="train"
+        graph=graph,
+        features=features,
+        itemset=train_set,
+        batch_size=args.batch_size,
+        fanout=args.fanout,
+        device=args.device,
+        num_workers=args.num_workers,
+        job="train",
     )
 
     for epoch in range(args.epochs):
```
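All three call sites still read their values from `args`, so the refactor narrows `create_dataloader()` without touching the CLI. Below is a hedged sketch of a matching argparse setup; the flag names and defaults are inferred from the attributes used in this diff (`batch_size`, `fanout`, `device`, `num_workers`, `lr`, `epochs`) and may differ from the example file's actual parser.

```python
import argparse

parser = argparse.ArgumentParser(
    description="GraphBolt node classification (sketch)"
)
parser.add_argument("--batch-size", dest="batch_size", type=int, default=1024)
parser.add_argument(
    "--fanout",
    type=lambda s: [int(v) for v in s.split(",")],
    default=[10, 10, 10],
    help="comma-separated neighbors to sample per layer, e.g. '10,10,10'",
)
parser.add_argument("--device", default="cpu")
parser.add_argument("--num-workers", dest="num_workers", type=int, default=0)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--epochs", type=int, default=10)

args = parser.parse_args([])  # defaults only, for illustration
print(args.batch_size, args.fanout, args.device, args.num_workers)
```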