Unverified Commit 701b4fcc authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Sampling] New sampling pipeline plus asynchronous prefetching (#3665)

* initial update

* more

* more

* multi-gpu example

* cluster gcn, finalize homogeneous

* more explanation

* fix

* bunch of fixes

* fix

* RGAT example and more fixes

* shadow-gnn sampler and some changes in unit test

* fix

* wth

* more fixes

* remove shadow+node/edge dataloader tests for possible ux changes

* lints

* add legacy dataloading import just in case

* fix

* update pylint for f-strings

* fix

* lint

* lint

* lint again

* cherry-picking commit fa9f494

* oops

* fix

* add sample_neighbors in dist_graph

* fix

* lint

* fix

* fix

* fix

* fix tutorial

* fix

* fix

* fix

* fix warning

* remove debug

* add get_foo_storage apis

* lint
parent 5152a879
...@@ -369,12 +369,6 @@ for epoch in range(1): ...@@ -369,12 +369,6 @@ for epoch in range(1):
# Ultimately, they require the model to predict one scalar score given # Ultimately, they require the model to predict one scalar score given
# a node pair among a set of node pairs. # a node pair among a set of node pairs.
# #
# ``dgl.dataloading.EdgeDataLoader`` allows you to iterate over
# the edges of a new graph with the same nodes, while performing
# neighbor sampling on the original graph with ``g_sampling`` argument.
# This functionality enables convenient evaluation of a link prediction
# model.
#
# Assuming that you have the following test set with labels, where # Assuming that you have the following test set with labels, where
# ``test_pos_src`` and ``test_pos_dst`` are ground truth node pairs # ``test_pos_src`` and ``test_pos_dst`` are ground truth node pairs
# with edges in between (or *positive* pairs), and ``test_neg_src`` # with edges in between (or *positive* pairs), and ``test_neg_src``
...@@ -383,10 +377,16 @@ for epoch in range(1): ...@@ -383,10 +377,16 @@ for epoch in range(1):
# #
# Positive pairs # Positive pairs
test_pos_src, test_pos_dst = graph.edges() # These are randomly generated as an example. You will need to
# Negative pairs # replace them with your own ground truth.
n_test_pos = 1000
test_pos_src, test_pos_dst = (
torch.randint(0, graph.num_nodes(), (n_test_pos,)),
torch.randint(0, graph.num_nodes(), (n_test_pos,)))
# Negative pairs. Likewise, you will need to replace them with your
# own ground truth.
test_neg_src = test_pos_src test_neg_src = test_pos_src
test_neg_dst = torch.randint(0, graph.num_nodes(), (graph.num_edges(),)) test_neg_dst = torch.randint(0, graph.num_nodes(), (n_test_pos,))
###################################################################### ######################################################################
...@@ -398,10 +398,20 @@ test_neg_dst = torch.randint(0, graph.num_nodes(), (graph.num_edges(),)) ...@@ -398,10 +398,20 @@ test_neg_dst = torch.randint(0, graph.num_nodes(), (graph.num_edges(),))
test_src = torch.cat([test_pos_src, test_pos_dst]) test_src = torch.cat([test_pos_src, test_pos_dst])
test_dst = torch.cat([test_neg_src, test_neg_dst]) test_dst = torch.cat([test_neg_src, test_neg_dst])
test_graph = dgl.graph((test_src, test_dst), num_nodes=graph.num_nodes()) test_graph = dgl.graph((test_src, test_dst), num_nodes=graph.num_nodes())
test_graph.edata['label'] = torch.cat( test_ground_truth = torch.cat(
[torch.ones_like(test_pos_src), torch.zeros_like(test_neg_src)]) [torch.ones_like(test_pos_src), torch.zeros_like(test_neg_src)])
######################################################################
# You will need to merge the test graph with the original graph. The
# testing edges' ID will be starting from ``graph.num_edges()``.
#
new_graph = dgl.merge([graph, test_graph])
test_edge_ids = torch.arange(graph.num_edges(), new_graph.num_edges())
###################################################################### ######################################################################
# Then you could create a new ``EdgeDataLoader`` instance that # Then you could create a new ``EdgeDataLoader`` instance that
# iterates on the new ``test_graph``, but uses the original ``graph`` # iterates on the new ``test_graph``, but uses the original ``graph``
...@@ -412,11 +422,11 @@ test_graph.edata['label'] = torch.cat( ...@@ -412,11 +422,11 @@ test_graph.edata['label'] = torch.cat(
# #
test_dataloader = dgl.dataloading.EdgeDataLoader( test_dataloader = dgl.dataloading.EdgeDataLoader(
# The following arguments are specific to EdgeDataLoader. # The following arguments are specific to EdgeDataLoader.
test_graph, # The graph to iterate edges over new_graph, # The graph to iterate edges over
torch.arange(test_graph.number_of_edges()), # The edges to iterate over test_edge_ids, # The edges to iterate over
sampler, # The neighbor sampler sampler, # The neighbor sampler
device=device, # Put the MFGs on CPU or GPU device=device, # Put the MFGs on CPU or GPU
g_sampling=graph, # Graph to sample neighbors exclude=test_edge_ids, # Do not sample test edges as neighbors
# The following arguments are inherited from PyTorch DataLoader. # The following arguments are inherited from PyTorch DataLoader.
batch_size=1024, # Batch size batch_size=1024, # Batch size
shuffle=True, # Whether to shuffle the nodes for every epoch shuffle=True, # Whether to shuffle the nodes for every epoch
...@@ -447,7 +457,10 @@ with tqdm.tqdm(test_dataloader) as tq, torch.no_grad(): ...@@ -447,7 +457,10 @@ with tqdm.tqdm(test_dataloader) as tq, torch.no_grad():
outputs = model(mfgs, inputs) outputs = model(mfgs, inputs)
test_preds.append(predictor(pair_graph, outputs)) test_preds.append(predictor(pair_graph, outputs))
test_labels.append(pair_graph.edata['label']) test_labels.append(
# Need to map the IDs of test edges in the merged graph back
# to that of test_ground_truth.
test_ground_truth[pair_graph.edata[dgl.EID] - graph.num_edges()])
test_preds = torch.cat(test_preds).cpu().numpy() test_preds = torch.cat(test_preds).cpu().numpy()
test_labels = torch.cat(test_labels).cpu().numpy() test_labels = torch.cat(test_labels).cpu().numpy()
......
...@@ -158,7 +158,6 @@ def run(proc_id, devices): ...@@ -158,7 +158,6 @@ def run(proc_id, devices):
# Copied from previous tutorial with changes highlighted. # Copied from previous tutorial with changes highlighted.
for epoch in range(10): for epoch in range(10):
train_dataloader.set_epoch(epoch) # <--- necessary for dataloader with DDP.
model.train() model.train()
with tqdm.tqdm(train_dataloader) as tq: with tqdm.tqdm(train_dataloader) as tq:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment