OpenDAS / dgl · Commit 4c6e6543 (Unverified)
Authored Nov 08, 2023 by Mingbang Wang; committed by GitHub on Nov 08, 2023

[Misc] Modify `create_dataloader()` (#6537)

Parent: 382a2de7
Showing 1 changed file with 36 additions and 22 deletions.

examples/sampling/graphbolt/node_classification.py (+36, -22)
@@ -49,7 +49,9 @@ import torchmetrics.functional as MF
 from tqdm import tqdm


-def create_dataloader(args, graph, features, itemset, job):
+def create_dataloader(
+    graph, features, itemset, batch_size, fanout, device, num_workers, job
+):
     """
     Get a GraphBolt version of a dataloader for node classification tasks.
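The visible effect of this change at call sites: instead of threading a parsed argparse namespace into the loader, every knob is passed explicitly. A minimal before/after sketch, using only the fields this file already reads from `args`:

```python
# Before: the dataloader factory reached into the CLI namespace itself.
dataloader = create_dataloader(args, graph, features, train_set, job="train")

# After: each setting is an explicit keyword argument, so the function
# no longer depends on how (or whether) the caller parses a command line.
dataloader = create_dataloader(
    graph=graph,
    features=features,
    itemset=train_set,
    batch_size=args.batch_size,
    fanout=args.fanout,
    device=args.device,
    num_workers=args.num_workers,
    job="train",
)
```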
@@ -59,17 +61,10 @@ def create_dataloader(args, graph, features, itemset, job):
     Parameters
     ----------
-    args : Namespace
-        The arguments parsed by `parser.parse_args()`.
-    graph : SamplingGraph
-        The network topology for sampling.
-    features : FeatureStore
-        The node features.
-    itemset : Union[ItemSet, ItemSetDict]
-        Data to be sampled.
     job : one of ["train", "evaluate", "infer"]
         The stage where dataloader is created, with options "train", "evaluate"
         and "infer".
+    Other parameters are explicated in the comments below.
     """
     ############################################################################
@@ -77,7 +72,7 @@ def create_dataloader(args, graph, features, itemset, job):
     # gb.ItemSampler()
     # [Input]:
     # 'itemset': The current dataset. (e.g. `train_set` or `valid_set`)
-    # 'args.batch_size': Specify the number of samples to be processed together,
+    # 'batch_size': Specify the number of samples to be processed together,
     # referred to as a 'mini-batch'. (The term 'mini-batch' is used here to
     # indicate a subset of the entire dataset that is processed together. This
     # is in contrast to processing the entire dataset, known as a 'full batch'.)
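As a concrete illustration of the mini-batch/full-batch distinction drawn in this comment: with 1,000 seed nodes and `batch_size=256`, the sampler would emit four mini-batches (three of 256 seeds and, assuming the last partial batch is kept, a final one of 232), whereas full-batch processing would handle all 1,000 seeds in a single step.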
@@ -91,7 +86,7 @@ def create_dataloader(args, graph, features, itemset, job):
     # Initialize the ItemSampler to sample mini-batche from the dataset.
     ############################################################################
     datapipe = gb.ItemSampler(
-        itemset, batch_size=args.batch_size, shuffle=(job == "train")
+        itemset, batch_size=batch_size, shuffle=(job == "train")
     )
     ############################################################################
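To see what this first stage produces in isolation, here is a small sketch; the toy `ItemSet` and its `("seed_nodes", "labels")` field names are illustrative assumptions, not taken from this file:

```python
import torch
import dgl.graphbolt as gb

# Hypothetical toy itemset: ten seed nodes with dummy labels.
itemset = gb.ItemSet(
    (torch.arange(10), torch.zeros(10, dtype=torch.long)),
    names=("seed_nodes", "labels"),
)

# The file uses shuffle=(job == "train"); fixed order here for readability.
sampler = gb.ItemSampler(itemset, batch_size=4, shuffle=False)
for minibatch in sampler:
    # Each minibatch groups up to `batch_size` seeds with their labels.
    print(minibatch.seed_nodes, minibatch.labels)
```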
@@ -99,8 +94,8 @@ def create_dataloader(args, graph, features, itemset, job):
     # self.sample_neighbor()
     # [Input]:
     # 'graph': The network topology for sampling.
-    # '[-1] or args.fanout': Number of neighbors to sample per node. In
-    # training or validation, the length of args.fanout should be equal to the
+    # '[-1] or fanout': Number of neighbors to sample per node. In
+    # training or validation, the length of `fanout` should be equal to the
     # number of layers in the model. In inference, this parameter is set to
     # [-1], indicating that all neighbors of a node are sampled.
     # [Output]:
@@ -109,7 +104,7 @@ def create_dataloader(args, graph, features, itemset, job):
     # Initialize a neighbor sampler for sampling the neighborhoods of nodes.
     ############################################################################
     datapipe = datapipe.sample_neighbor(
-        graph, args.fanout if job != "infer" else [-1]
+        graph, fanout if job != "infer" else [-1]
     )
     ############################################################################
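The branch on `job` encodes a standard pattern: bounded sampling while learning, exhaustive neighborhoods at inference time. Sketched for a hypothetical two-layer model:

```python
# Training/validation: one fanout entry per model layer, so a two-layer
# model sampling at most 10 first-hop and 25 second-hop neighbors uses:
fanout = [10, 25]

# Inference: [-1] means "take every neighbor", giving exact neighborhoods
# at the cost of larger sampled subgraphs.
effective_fanout = fanout if job != "infer" else [-1]
```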
@@ -148,22 +143,20 @@ def create_dataloader(args, graph, features, itemset, job):
     # [Output]:
     # A CopyTo object to copy the data to the specified device.
     ############################################################################
-    datapipe = datapipe.copy_to(device=args.device)
+    datapipe = datapipe.copy_to(device=device)
     ############################################################################
     # [Step-6]:
     # gb.MultiProcessDataLoader()
     # [Input]:
     # 'datapipe': The datapipe object to be used for data loading.
-    # 'args.num_workers': The number of processes to be used for data loading.
+    # 'num_workers': The number of processes to be used for data loading.
     # [Output]:
     # A MultiProcessDataLoader object to handle data loading.
     # [Role]:
     # Initialize a multi-process dataloader to load the data in parallel.
     ############################################################################
-    dataloader = gb.MultiProcessDataLoader(
-        datapipe, num_workers=args.num_workers
-    )
+    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=num_workers)

     # Return the fully-initialized DataLoader object.
     return dataloader
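Taken together, the refactored body is a linear datapipe composition over the new parameters. A condensed sketch, assembled only from the calls visible in this diff (it omits the feature-fetching stage that sits between sampling and the device copy in the full file, which is what `features` feeds):

```python
import dgl.graphbolt as gb

def create_dataloader_sketch(
    graph, features, itemset, batch_size, fanout, device, num_workers, job
):
    # Condensed from the hunks above; not the complete pipeline.
    datapipe = gb.ItemSampler(
        itemset, batch_size=batch_size, shuffle=(job == "train")
    )
    datapipe = datapipe.sample_neighbor(
        graph, fanout if job != "infer" else [-1]
    )
    datapipe = datapipe.copy_to(device=device)
    return gb.MultiProcessDataLoader(datapipe, num_workers=num_workers)
```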
@@ -240,7 +233,14 @@ def layerwise_infer(
 ):
     model.eval()
     dataloader = create_dataloader(
-        args, graph, features, all_nodes_set, job="infer"
+        graph=graph,
+        features=features,
+        itemset=all_nodes_set,
+        batch_size=4 * args.batch_size,
+        fanout=[-1],
+        device=args.device,
+        num_workers=args.num_workers,
+        job="infer",
     )
     pred = model.inference(graph, features, dataloader, args.device)
     pred = pred[test_set._items[0]]
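Two choices in this call are worth noting: `fanout=[-1]` requests full neighborhoods, matching the docstring's rule that inference samples all neighbors, and the batch size is scaled up to `4 * args.batch_size`. The factor of 4 is a heuristic in this example script; larger batches are generally affordable at inference because the layerwise pass runs under `model.eval()` without training-time gradient state.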
@@ -260,7 +260,14 @@ def evaluate(args, model, graph, features, itemset, num_classes):
     y = []
     y_hats = []
     dataloader = create_dataloader(
-        args, graph, features, itemset, job="evaluate"
+        graph=graph,
+        features=features,
+        itemset=itemset,
+        batch_size=args.batch_size,
+        fanout=args.fanout,
+        device=args.device,
+        num_workers=args.num_workers,
+        job="evaluate",
     )

     for step, data in tqdm(enumerate(dataloader)):
@@ -279,7 +286,14 @@ def evaluate(args, model, graph, features, itemset, num_classes):
 def train(args, graph, features, train_set, valid_set, num_classes, model):
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
     dataloader = create_dataloader(
-        args, graph, features, train_set, job="train"
+        graph=graph,
+        features=features,
+        itemset=train_set,
+        batch_size=args.batch_size,
+        fanout=args.fanout,
+        device=args.device,
+        num_workers=args.num_workers,
+        job="train",
     )

     for epoch in range(args.epochs):
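A side benefit of dropping the `args` parameter is that `create_dataloader()` can now be driven without an argparse namespace at all, e.g. from a notebook or a test. A hypothetical snippet: the `gb.BuiltinDataset(...).load()` line follows the GraphBolt builtin-dataset convention used by these examples and is an assumption here, as are the concrete numbers:

```python
import dgl.graphbolt as gb

# Hypothetical: load a built-in dataset the way these examples do.
dataset = gb.BuiltinDataset("ogbn-products").load()

dataloader = create_dataloader(
    graph=dataset.graph,
    features=dataset.feature,
    itemset=dataset.tasks[0].train_set,
    batch_size=512,        # no CLI parsing required
    fanout=[10, 10, 10],   # three entries for a three-layer model
    device="cpu",
    num_workers=0,
    job="train",
)
```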