Unverified Commit df3e7f1b authored by xiangyuzhi's avatar xiangyuzhi Committed by GitHub
Browse files

[Sparse] Polish sparse example for graphbolt (#6691)

parent 3cdc37cc
...@@ -34,9 +34,8 @@ import torch ...@@ -34,9 +34,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchmetrics.functional as MF import torchmetrics.functional as MF
from dgl.data import AsNodePredDataset from dgl.graphbolt.subgraph_sampler import SubgraphSampler
from ogb.nodeproppred import DglNodePropPredDataset from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterableWrapper
class SAGEConv(nn.Module): class SAGEConv(nn.Module):
...@@ -102,9 +101,21 @@ class SAGE(nn.Module): ...@@ -102,9 +101,21 @@ class SAGE(nn.Module):
return hidden_x return hidden_x
@functional_datapipe("sample_sparse_neighbor")
class SparseNeighborSampler(SubgraphSampler):
    """Graphbolt subgraph sampler backed by DGL sparse-matrix operators.

    Registered on the datapipe under the name ``sample_sparse_neighbor`` so
    callers can chain it fluently, e.g.
    ``datapipe.sample_sparse_neighbor(A, fanouts)``.

    Parameters
    ----------
    datapipe : the upstream datapipe yielding seed-node minibatches.
    matrix : the (sparse) adjacency matrix to sample neighbors from.
    fanouts : per-layer neighbor counts; ints are converted to LongTensors.
    """

    def __init__(self, datapipe, matrix, fanouts):
        super().__init__(datapipe)
        self.matrix = matrix
        # Normalize every fanout to a tensor and store them in REVERSE
        # order: sampling walks from the seeds (output layer) back toward
        # the input layer.
        self.fanouts = []
        for fanout in fanouts:
            if not isinstance(fanout, torch.Tensor):
                fanout = torch.LongTensor([int(fanout)])
            self.fanouts.insert(0, fanout)

    def _sample_subgraphs(self, seeds):
        #####################################################################
        # (HIGHLIGHT) Use the sparse sample operator to perform random
        # neighbor sampling layer by layer; the compact operator is then
        # employed to compact and relabel the sampled matrix, yielding the
        # compacted matrix and the relabel index.
        #####################################################################
        frontier = seeds
        sampled_matrices = []
        for fanout in self.fanouts:
            # Sample neighbors of the current frontier.
            sampled = self.matrix.sample(1, fanout, ids=frontier).coalesce()
            # Compact and relabel the rows of the sampled matrix.
            compacted, row_ids = sampled.compact(0)
            # Prepend so the final list is ordered input layer -> output layer.
            sampled_matrices.insert(0, compacted)
            frontier = row_ids
        # Return the input nodes of the deepest layer plus the per-layer
        # compacted matrices.
        return frontier, sampled_matrices
############################################################################ ############################################################################
...@@ -132,11 +140,11 @@ def multilayer_sample(A, fanouts, minibatch): ...@@ -132,11 +140,11 @@ def multilayer_sample(A, fanouts, minibatch):
def create_dataloader(A, fanouts, ids, features, device):
    """Assemble a graphbolt dataloader whose sampling step uses sparse ops.

    Parameters
    ----------
    A : sparse adjacency matrix used for neighbor sampling.
    fanouts : per-layer neighbor counts forwarded to the sampler.
    ids : seed-node item set to iterate over in minibatches.
    features : feature store queried for the "feat" node feature.
    device : target device each minibatch is copied to.
    """
    # Batch the seed ids into minibatches of 1024.
    pipe = gb.ItemSampler(ids, batch_size=1024)
    # Customize graphbolt sampler by sparse.
    pipe = pipe.sample_sparse_neighbor(A, fanouts)
    # Use graphbolt to fetch features.
    pipe = pipe.fetch_feature(features, node_feature_keys=["feat"])
    # Move the assembled minibatch to the requested device.
    pipe = pipe.copy_to(device)
    return gb.DataLoader(pipe)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment