Unverified Commit 3e26c3d1 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

fix caregnn (#4211)


Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 52d43127
......@@ -92,9 +92,10 @@ def main(args):
graph.ndata['nd'] = th.tanh(model.layers[i].MLP(layers_feat[i]))
for etype in graph.canonical_etypes:
graph.apply_edges(_l1_dist, etype=etype)
dist[etype] = graph.edges[etype].data['ed']
dist[etype] = graph.edges[etype].data.pop('ed').detach().cpu()
dists.append(dist)
p.append(model.layers[i].p)
graph.ndata.pop('nd')
sampler = CARESampler(p, dists, args.num_layers)
# train
......@@ -103,14 +104,9 @@ def main(args):
tr_recall = 0
tr_auc = 0
tr_blk = 0
train_dataloader = dgl.dataloading.DataLoader(graph,
train_idx,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers
)
train_dataloader = dgl.dataloading.DataLoader(
graph, train_idx, sampler, batch_size=args.batch_size,
shuffle=True, drop_last=False, num_workers=args.num_workers)
for input_nodes, output_nodes, blocks in train_dataloader:
blocks = [b.to(device) for b in blocks]
......@@ -135,14 +131,9 @@ def main(args):
# validation
model.eval()
val_dataloader = dgl.dataloading.DataLoader(graph,
val_idx,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers
)
val_dataloader = dgl.dataloading.DataLoader(
graph, val_idx, sampler, batch_size=args.batch_size,
shuffle=True, drop_last=False, num_workers=args.num_workers)
val_recall, val_auc, val_loss = evaluate(model, loss_fn, val_dataloader, device)
......@@ -159,14 +150,9 @@ def main(args):
model.eval()
if args.early_stop:
model.load_state_dict(th.load('es_checkpoint.pt'))
test_dataloader = dgl.dataloading.DataLoader(graph,
test_idx,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers
)
test_dataloader = dgl.dataloading.DataLoader(
graph, test_idx, sampler, batch_size=args.batch_size,
shuffle=True, drop_last=False, num_workers=args.num_workers)
test_recall, test_auc, test_loss = evaluate(model, loss_fn, test_dataloader, device)
......
......@@ -13,9 +13,10 @@ def _l1_dist(edges):
class CARESampler(dgl.dataloading.BlockSampler):
def __init__(self, p, dists, num_layers):
super().__init__(num_layers)
super().__init__()
self.p = p
self.dists = dists
self.num_layers = num_layers
def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):
with g.local_scope():
......@@ -28,7 +29,7 @@ class CARESampler(dgl.dataloading.BlockSampler):
num_neigh = th.ceil(g.in_degrees(node, etype=etype) * self.p[block_id][etype]).int().item()
neigh_dist = self.dists[block_id][etype][edges]
if neigh_dist.shape[0] > num_neigh:
neigh_index = np.argpartition(neigh_dist.cpu().detach(), num_neigh)[:num_neigh]
neigh_index = np.argpartition(neigh_dist, num_neigh)[:num_neigh]
else:
neigh_index = np.arange(num_neigh)
edge_mask[edges[neigh_index]] = 1
......@@ -36,6 +37,19 @@ class CARESampler(dgl.dataloading.BlockSampler):
return dgl.edge_subgraph(g, new_edges_masks, relabel_nodes=False)
def sample_blocks(self, g, seed_nodes, exclude_eids=None):
output_nodes = seed_nodes
blocks = []
for block_id in reversed(range(self.num_layers)):
frontier = self.sample_frontier(block_id, g, seed_nodes)
eid = frontier.edata[dgl.EID]
block = dgl.to_block(frontier, seed_nodes)
block.edata[dgl.EID] = eid
seed_nodes = block.srcdata[dgl.NID]
blocks.insert(0, block)
return seed_nodes, output_nodes, blocks
def __len__(self):
return self.num_layers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment