Unverified Commit 4c6e6543 authored by Mingbang Wang, committed by GitHub

[Misc] Modify `create_dataloader()` (#6537)

parent 382a2de7
@@ -49,7 +49,9 @@ import torchmetrics.functional as MF
 from tqdm import tqdm
 
-def create_dataloader(args, graph, features, itemset, job):
+def create_dataloader(
+    graph, features, itemset, batch_size, fanout, device, num_workers, job
+):
     """
     [HIGHLIGHT]
     Get a GraphBolt version of a dataloader for node classification tasks.
@@ -59,17 +61,10 @@ def create_dataloader(args, graph, features, itemset, job):
     Parameters
     ----------
-    args : Namespace
-        The arguments parsed by `parser.parse_args()`.
-    graph : SamplingGraph
-        The network topology for sampling.
-    features : FeatureStore
-        The node features.
-    itemset : Union[ItemSet, ItemSetDict]
-        Data to be sampled.
     job : one of ["train", "evaluate", "infer"]
         The stage where dataloader is created, with options "train", "evaluate"
         and "infer".
+    Other parameters are explained in the comments below.
     """
 
     ############################################################################
@@ -77,7 +72,7 @@ def create_dataloader(args, graph, features, itemset, job):
     # gb.ItemSampler()
     # [Input]:
     # 'itemset': The current dataset. (e.g. `train_set` or `valid_set`)
-    # 'args.batch_size': Specify the number of samples to be processed together,
+    # 'batch_size': Specify the number of samples to be processed together,
     # referred to as a 'mini-batch'. (The term 'mini-batch' is used here to
     # indicate a subset of the entire dataset that is processed together. This
     # is in contrast to processing the entire dataset, known as a 'full batch'.)
@@ -91,7 +86,7 @@ def create_dataloader(args, graph, features, itemset, job):
     # Initialize the ItemSampler to sample mini-batches from the dataset.
     ############################################################################
     datapipe = gb.ItemSampler(
-        itemset, batch_size=args.batch_size, shuffle=(job == "train")
+        itemset, batch_size=batch_size, shuffle=(job == "train")
     )
     ############################################################################
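The two hunks above only rename `args.batch_size` to `batch_size`, but Step-1 itself is worth seeing in isolation. A minimal sketch of `gb.ItemSampler` on a toy item set, assuming DGL's GraphBolt API (`dgl.graphbolt`); the sizes and the `names=` label here are illustrative, not taken from this commit:

    import torch
    import dgl.graphbolt as gb

    # Eight node ids, batched three at a time; the last mini-batch is smaller.
    itemset = gb.ItemSet(torch.arange(8), names="seed_nodes")
    sampler = gb.ItemSampler(itemset, batch_size=3, shuffle=False)
    for minibatch in sampler:
        # The mini-batch attribute follows the `names=` argument.
        print(minibatch.seed_nodes)  # tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])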
@@ -99,8 +94,8 @@ def create_dataloader(args, graph, features, itemset, job):
     # self.sample_neighbor()
     # [Input]:
     # 'graph': The network topology for sampling.
-    # '[-1] or args.fanout': Number of neighbors to sample per node. In
-    # training or validation, the length of args.fanout should be equal to the
+    # '[-1] or fanout': Number of neighbors to sample per node. In
+    # training or validation, the length of `fanout` should be equal to the
     # number of layers in the model. In inference, this parameter is set to
     # [-1], indicating that all neighbors of a node are sampled.
     # [Output]:
@@ -109,7 +104,7 @@ def create_dataloader(args, graph, features, itemset, job):
     # Initialize a neighbor sampler for sampling the neighborhoods of nodes.
     ############################################################################
     datapipe = datapipe.sample_neighbor(
-        graph, args.fanout if job != "infer" else [-1]
+        graph, fanout if job != "infer" else [-1]
     )
     ############################################################################
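The fanout convention described above, restated as a short sketch; the model depth and per-hop cap are hypothetical values, not from this commit:

    job = "train"  # one of "train", "evaluate", "infer"
    num_layers = 2  # depth of the hypothetical GNN
    # One fanout entry per layer while training or validating; [-1] keeps
    # every neighbor during inference.
    fanout = [10] * num_layers if job != "infer" else [-1]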
@@ -148,22 +143,20 @@ def create_dataloader(args, graph, features, itemset, job):
     # [Output]:
     # A CopyTo object to copy the data to the specified device.
     ############################################################################
-    datapipe = datapipe.copy_to(device=args.device)
+    datapipe = datapipe.copy_to(device=device)
     ############################################################################
     # [Step-6]:
     # gb.MultiProcessDataLoader()
     # [Input]:
     # 'datapipe': The datapipe object to be used for data loading.
-    # 'args.num_workers': The number of processes to be used for data loading.
+    # 'num_workers': The number of processes to be used for data loading.
     # [Output]:
     # A MultiProcessDataLoader object to handle data loading.
     # [Role]:
     # Initialize a multi-process dataloader to load the data in parallel.
     ############################################################################
-    dataloader = gb.MultiProcessDataLoader(
-        datapipe, num_workers=args.num_workers
-    )
+    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=num_workers)
 
     # Return the fully-initialized DataLoader object.
     return dataloader
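Putting the steps together: a hedged end-to-end sketch of the datapipe this function builds, assuming the GraphBolt API of this era. `gb.BuiltinDataset` supplies a real `graph`/`features` pair; the dataset name, batch size, fanout, and worker count are illustrative. The diff collapses Steps 3-4, so `fetch_feature` below stands in for that part, and later GraphBolt releases expose the loader as `gb.DataLoader` rather than `gb.MultiProcessDataLoader`:

    import dgl.graphbolt as gb

    # Load a prepared dataset; the name is illustrative.
    dataset = gb.BuiltinDataset("ogbn-arxiv").load()
    graph = dataset.graph
    features = dataset.feature
    train_set = dataset.tasks[0].train_set

    # Step-1: mini-batches of seed nodes.
    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
    # Step-2: sample a 2-hop neighborhood, at most 10 neighbors per hop.
    datapipe = datapipe.sample_neighbor(graph, [10, 10])
    # Steps 3-4 (collapsed in the diff): fetch input features for the
    # sampled subgraphs.
    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
    # Step-5: move each mini-batch to the target device.
    datapipe = datapipe.copy_to(device="cpu")
    # Step-6: wrap the datapipe in a dataloader; num_workers=0 stays in-process.
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

    for minibatch in dataloader:
        pass  # e.g. minibatch.node_features["feat"], minibatch.labels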
@@ -240,7 +233,14 @@ def layerwise_infer(
 ):
     model.eval()
     dataloader = create_dataloader(
-        args, graph, features, all_nodes_set, job="infer"
+        graph=graph,
+        features=features,
+        itemset=all_nodes_set,
+        batch_size=4 * args.batch_size,
+        fanout=[-1],
+        device=args.device,
+        num_workers=args.num_workers,
+        job="infer",
     )
     pred = model.inference(graph, features, dataloader, args.device)
     pred = pred[test_set._items[0]]
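Two details of this call site are easy to miss: `fanout=[-1]` requests every neighbor (full-neighborhood, layer-wise inference), and the batch size is four times the training batch size, presumably affordable because inference runs without gradient state.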
@@ -260,7 +260,14 @@ def evaluate(args, model, graph, features, itemset, num_classes):
     y = []
     y_hats = []
     dataloader = create_dataloader(
-        args, graph, features, itemset, job="evaluate"
+        graph=graph,
+        features=features,
+        itemset=itemset,
+        batch_size=args.batch_size,
+        fanout=args.fanout,
+        device=args.device,
+        num_workers=args.num_workers,
+        job="evaluate",
     )
     for step, data in tqdm(enumerate(dataloader)):
@@ -279,7 +286,14 @@ def evaluate(args, model, graph, features, itemset, num_classes):
 def train(args, graph, features, train_set, valid_set, num_classes, model):
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
     dataloader = create_dataloader(
-        args, graph, features, train_set, job="train"
+        graph=graph,
+        features=features,
+        itemset=train_set,
+        batch_size=args.batch_size,
+        fanout=args.fanout,
+        device=args.device,
+        num_workers=args.num_workers,
+        job="train",
     )
     for epoch in range(args.epochs):
...
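All three call sites pull their values from `args`, so for completeness, a hedged sketch of the argparse wiring they assume. The attribute names (`args.batch_size`, `args.fanout`, `args.device`, `args.num_workers`, `args.lr`, `args.epochs`) appear in the diff; the flag spellings and defaults below are illustrative guesses, not the example's actual parser:

    import argparse

    parser = argparse.ArgumentParser(description="GraphBolt node classification")
    parser.add_argument("--batch-size", type=int, default=1024)
    parser.add_argument(
        "--fanout",
        type=lambda s: [int(f) for f in s.split(",")],
        default=[10, 10],
        help="comma-separated neighbors to sample per layer, e.g. '10,10'",
    )
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--epochs", type=int, default=10)
    args = parser.parse_args()
    # argparse maps --batch-size to args.batch_size and --num-workers to
    # args.num_workers, matching the attributes the call sites read.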