"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "ba63fbdb595f41901f074883abc0084145877cf5"
Unverified Commit a3b09f74 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Fix sparse example. (#7234)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent add12e57
...@@ -36,6 +36,7 @@ import torch.nn.functional as F ...@@ -36,6 +36,7 @@ import torch.nn.functional as F
import torchmetrics.functional as MF import torchmetrics.functional as MF
from dgl.graphbolt.subgraph_sampler import SubgraphSampler from dgl.graphbolt.subgraph_sampler import SubgraphSampler
from torch.utils.data import functional_datapipe from torch.utils.data import functional_datapipe
from tqdm import tqdm
class SAGEConv(nn.Module): class SAGEConv(nn.Module):
...@@ -113,7 +114,7 @@ class SparseNeighborSampler(SubgraphSampler): ...@@ -113,7 +114,7 @@ class SparseNeighborSampler(SubgraphSampler):
fanout = torch.LongTensor([int(fanout)]) fanout = torch.LongTensor([int(fanout)])
self.fanouts.insert(0, fanout) self.fanouts.insert(0, fanout)
def sample_subgraphs(self, seeds): def sample_subgraphs(self, seeds, seeds_timestamp=None):
sampled_matrices = [] sampled_matrices = []
src = seeds src = seeds
...@@ -152,7 +153,7 @@ def evaluate(model, dataloader, num_classes): ...@@ -152,7 +153,7 @@ def evaluate(model, dataloader, num_classes):
model.eval() model.eval()
ys = [] ys = []
y_hats = [] y_hats = []
for it, data in enumerate(dataloader): for it, data in tqdm(enumerate(dataloader), "Evaluating"):
with torch.no_grad(): with torch.no_grad():
node_feature = data.node_features["feat"].float() node_feature = data.node_features["feat"].float()
blocks = data.sampled_subgraphs blocks = data.sampled_subgraphs
...@@ -194,7 +195,7 @@ def train(device, A, features, dataset, num_classes, model): ...@@ -194,7 +195,7 @@ def train(device, A, features, dataset, num_classes, model):
for epoch in range(10): for epoch in range(10):
model.train() model.train()
total_loss = 0 total_loss = 0
for it, data in enumerate(train_dataloader): for it, data in tqdm(enumerate(train_dataloader), "Training"):
node_feature = data.node_features["feat"].float() node_feature = data.node_features["feat"].float()
blocks = data.sampled_subgraphs blocks = data.sampled_subgraphs
y = data.labels y = data.labels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment