Unverified Commit df3e7f1b authored by xiangyuzhi's avatar xiangyuzhi Committed by GitHub
Browse files

[Sparse] Polish sparse example for graphbolt (#6691)

parent 3cdc37cc
......@@ -34,9 +34,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
from dgl.data import AsNodePredDataset
from ogb.nodeproppred import DglNodePropPredDataset
from torchdata.datapipes.iter import IterableWrapper
from dgl.graphbolt.subgraph_sampler import SubgraphSampler
from torch.utils.data import functional_datapipe
class SAGEConv(nn.Module):
......@@ -102,9 +101,21 @@ class SAGE(nn.Module):
return hidden_x
def multilayer_sample(A, fanouts, minibatch):
@functional_datapipe("sample_sparse_neighbor")
class SparseNeighborSampler(SubgraphSampler):
def __init__(self, datapipe, matrix, fanouts):
super().__init__(datapipe)
self.matrix = matrix
# Convert fanouts to a list of tensors.
self.fanouts = []
for fanout in fanouts:
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
self.fanouts.insert(0, fanout)
def _sample_subgraphs(self, seeds):
sampled_matrices = []
src = minibatch.seed_nodes
src = seeds
#####################################################################
# (HIGHLIGHT) Using the sparse sample operator to preform random
......@@ -112,18 +123,15 @@ def multilayer_sample(A, fanouts, minibatch):
# compact operator is then employed to compact and relabel the sampled
# matrix, resulting in the sampled matrix and the relabel index.
#####################################################################
for fanout in fanouts:
for fanout in self.fanouts:
# Sample neighbors.
sampled_matrix = A.sample(1, fanout, ids=src).coalesce()
sampled_matrix = self.matrix.sample(1, fanout, ids=src).coalesce()
# Compact the sampled matrix.
compacted_mat, row_ids = sampled_matrix.compact(0)
sampled_matrices.insert(0, compacted_mat)
src = row_ids
minibatch.input_nodes = src
minibatch.sampled_subgraphs = sampled_matrices
return minibatch
return src, sampled_matrices
############################################################################
......@@ -132,11 +140,11 @@ def multilayer_sample(A, fanouts, minibatch):
def create_dataloader(A, fanouts, ids, features, device):
datapipe = gb.ItemSampler(ids, batch_size=1024)
# Customize graphbolt sampler by sparse.
datapipe = datapipe.map(partial(multilayer_sample, A, fanouts))
datapipe = datapipe.sample_sparse_neighbor(A, fanouts)
# Use grapbolt to fetch features.
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe, num_workers=4)
dataloader = gb.DataLoader(datapipe)
return dataloader
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment