Unverified Commit a3b09f74 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Fix sparse example. (#7234)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent add12e57
......@@ -36,6 +36,7 @@ import torch.nn.functional as F
import torchmetrics.functional as MF
from dgl.graphbolt.subgraph_sampler import SubgraphSampler
from torch.utils.data import functional_datapipe
from tqdm import tqdm
class SAGEConv(nn.Module):
......@@ -113,7 +114,7 @@ class SparseNeighborSampler(SubgraphSampler):
fanout = torch.LongTensor([int(fanout)])
self.fanouts.insert(0, fanout)
def sample_subgraphs(self, seeds):
def sample_subgraphs(self, seeds, seeds_timestamp=None):
sampled_matrices = []
src = seeds
......@@ -152,7 +153,7 @@ def evaluate(model, dataloader, num_classes):
model.eval()
ys = []
y_hats = []
for it, data in enumerate(dataloader):
for it, data in tqdm(enumerate(dataloader), "Evaluating"):
with torch.no_grad():
node_feature = data.node_features["feat"].float()
blocks = data.sampled_subgraphs
......@@ -194,7 +195,7 @@ def train(device, A, features, dataset, num_classes, model):
for epoch in range(10):
model.train()
total_loss = 0
for it, data in enumerate(train_dataloader):
for it, data in tqdm(enumerate(train_dataloader), "Training"):
node_feature = data.node_features["feat"].float()
blocks = data.sampled_subgraphs
y = data.labels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment