Unverified commit 4c6e6543 authored by Mingbang Wang, committed by GitHub

[Misc] Modify `create_dataloader()` (#6537)

parent 382a2de7
@@ -49,7 +49,9 @@ import torchmetrics.functional as MF
from tqdm import tqdm
def create_dataloader(args, graph, features, itemset, job):
def create_dataloader(
graph, features, itemset, batch_size, fanout, device, num_workers, job
):
"""
[HIGHLIGHT]
Get a GraphBolt version of a dataloader for node classification tasks.
@@ -59,17 +61,10 @@ def create_dataloader(args, graph, features, itemset, job):
Parameters
----------
args : Namespace
The arguments parsed by `parser.parse_args()`.
graph : SamplingGraph
The network topology for sampling.
features : FeatureStore
The node features.
itemset : Union[ItemSet, ItemSetDict]
Data to be sampled.
job : one of ["train", "evaluate", "infer"]
The stage where the dataloader is created: "train", "evaluate" or "infer".
Other parameters are explained in the step comments below.
"""
############################################################################
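With the `args` Namespace gone from the signature, the helper can be driven from plain Python values. A minimal sketch of the new calling convention (the concrete values are illustrative, not taken from this commit; `graph`, `features` and `train_set` are assumed to come from a loaded GraphBolt dataset):

```python
# Hypothetical direct call to the refactored helper.
train_dataloader = create_dataloader(
    graph=graph,
    features=features,
    itemset=train_set,
    batch_size=1024,      # illustrative value
    fanout=[10, 10, 10],  # one entry per GNN layer (illustrative)
    device="cuda:0",
    num_workers=4,
    job="train",
)
```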
@@ -77,7 +72,7 @@ def create_dataloader(graph, features, itemset, batch_size, fanout, device, num_workers, job):
# gb.ItemSampler()
# [Input]:
# 'itemset': The current dataset. (e.g. `train_set` or `valid_set`)
# 'args.batch_size': Specify the number of samples to be processed together,
# 'batch_size': Specify the number of samples to be processed together,
# referred to as a 'mini-batch'. (The term 'mini-batch' is used here to
# indicate a subset of the entire dataset that is processed together. This
# is in contrast to processing the entire dataset, known as a 'full batch'.)
@@ -91,7 +86,7 @@ def create_dataloader(graph, features, itemset, batch_size, fanout, device, num_workers, job):
# Initialize the ItemSampler to sample mini-batches from the dataset.
############################################################################
datapipe = gb.ItemSampler(
itemset, batch_size=args.batch_size, shuffle=(job == "train")
itemset, batch_size=batch_size, shuffle=(job == "train")
)
############################################################################
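Step-1 can be exercised in isolation on a toy dataset. A hedged sketch, assuming a GraphBolt build where `ItemSet` accepts a `names` argument:

```python
import torch
import dgl.graphbolt as gb

# Toy ItemSet of 8 seed-node IDs, batched into mini-batches of 4 and
# shuffled, mirroring the job == "train" branch above.
itemset = gb.ItemSet(torch.arange(8), names="seed_nodes")
datapipe = gb.ItemSampler(itemset, batch_size=4, shuffle=True)
for minibatch in datapipe:
    print(minibatch)  # a MiniBatch carrying 4 seed-node IDs per iteration
```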
@@ -99,8 +94,8 @@ def create_dataloader(graph, features, itemset, batch_size, fanout, device, num_workers, job):
# self.sample_neighbor()
# [Input]:
# 'graph': The network topology for sampling.
# '[-1] or args.fanout': Number of neighbors to sample per node. In
# training or validation, the length of args.fanout should be equal to the
# '[-1] or fanout': Number of neighbors to sample per node. In
# training or validation, the length of `fanout` should be equal to the
# number of layers in the model. In inference, this parameter is set to
# [-1], indicating that all neighbors of a node are sampled.
# [Output]:
@@ -109,7 +104,7 @@ def create_dataloader(graph, features, itemset, batch_size, fanout, device, num_workers, job):
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
datapipe = datapipe.sample_neighbor(
graph, args.fanout if job != "infer" else [-1]
graph, fanout if job != "infer" else [-1]
)
############################################################################
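The `fanout` switch above is the crux of the train/infer difference. A short sketch of the two configurations (fanout values are illustrative; `datapipe` and `graph` are assumed from the surrounding pipeline):

```python
# Training/validation: one fanout entry per model layer, e.g. a two-layer
# model sampling at most 10 first-hop and 5 second-hop neighbors.
datapipe_train = datapipe.sample_neighbor(graph, [10, 5])

# Inference: [-1] keeps every neighbor, so layer-wise inference sees the
# full neighborhood of each node.
datapipe_infer = datapipe.sample_neighbor(graph, [-1])
```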
@@ -148,22 +143,20 @@ def create_dataloader(graph, features, itemset, batch_size, fanout, device, num_workers, job):
# [Output]:
# A CopyTo object to copy the data to the specified device.
############################################################################
datapipe = datapipe.copy_to(device=args.device)
datapipe = datapipe.copy_to(device=device)
############################################################################
# [Step-6]:
# gb.MultiProcessDataLoader()
# [Input]:
# 'datapipe': The datapipe object to be used for data loading.
# 'args.num_workers': The number of processes to be used for data loading.
# 'num_workers': The number of processes to be used for data loading.
# [Output]:
# A MultiProcessDataLoader object to handle data loading.
# [Role]:
# Initialize a multi-process dataloader to load the data in parallel.
############################################################################
dataloader = gb.MultiProcessDataLoader(
datapipe, num_workers=args.num_workers
)
dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=num_workers)
# Return the fully-initialized DataLoader object.
return dataloader
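Each element yielded by the returned dataloader is a GraphBolt `MiniBatch`. A hedged sketch of the typical consumption pattern (attribute names follow the GraphBolt convention used elsewhere in this script; the feature-fetching step that populates `node_features` is elided from this diff):

```python
# Assumes `dataloader` was produced by create_dataloader(...) above.
for step, data in enumerate(dataloader):
    x = data.node_features["feat"]  # input features of the sampled subgraph
    y = data.labels                 # labels of the seed nodes
    blocks = data.blocks            # per-layer message-flow graphs
```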
@@ -240,7 +233,14 @@ def layerwise_infer(
):
model.eval()
dataloader = create_dataloader(
args, graph, features, all_nodes_set, job="infer"
graph=graph,
features=features,
itemset=all_nodes_set,
batch_size=4 * args.batch_size,
fanout=[-1],
device=args.device,
num_workers=args.num_workers,
job="infer",
)
pred = model.inference(graph, features, dataloader, args.device)
pred = pred[test_set._items[0]]
@@ -260,7 +260,14 @@ def evaluate(args, model, graph, features, itemset, num_classes):
y = []
y_hats = []
dataloader = create_dataloader(
args, graph, features, itemset, job="evaluate"
graph=graph,
features=features,
itemset=itemset,
batch_size=args.batch_size,
fanout=args.fanout,
device=args.device,
num_workers=args.num_workers,
job="evaluate",
)
for step, data in tqdm(enumerate(dataloader)):
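The loop truncated above accumulates predictions into `y_hats` and labels into `y`; the final metric is then a single torchmetrics call. A sketch of that computation (`MF` is `torchmetrics.functional`, imported at the top of the script; `num_classes` comes from the dataset):

```python
import torch

# Assumes y_hats and y were filled inside the evaluation loop above.
acc = MF.accuracy(
    torch.cat(y_hats),
    torch.cat(y),
    task="multiclass",
    num_classes=num_classes,
)
```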
@@ -279,7 +286,14 @@ def evaluate(args, model, graph, features, itemset, num_classes):
def train(args, graph, features, train_set, valid_set, num_classes, model):
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
dataloader = create_dataloader(
args, graph, features, train_set, job="train"
graph=graph,
features=features,
itemset=train_set,
batch_size=args.batch_size,
fanout=args.fanout,
device=args.device,
num_workers=args.num_workers,
job="train",
)
for epoch in range(args.epochs):
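The epoch body is truncated here. A hedged sketch of the inner training step in the usual GraphBolt pattern (not verbatim from this commit; `F` is `torch.nn.functional`):

```python
import torch.nn.functional as F

model.train()
for step, data in enumerate(dataloader):
    x = data.node_features["feat"]
    y = data.labels
    y_hat = model(data.blocks, x)     # forward over the sampled blocks
    loss = F.cross_entropy(y_hat, y)  # multi-class node classification
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```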