Unverified commit d20db1ec, authored by Andrei Ivanov and committed by GitHub
Browse files

Improving the RGCN_HETERO example (#6060)

parent 566719b1
...@@ -28,18 +28,19 @@ def evaluate(model, loader, node_embed, labels, category, device): ...@@ -28,18 +28,19 @@ def evaluate(model, loader, node_embed, labels, category, device):
total_loss = 0 total_loss = 0
total_acc = 0 total_acc = 0
count = 0 count = 0
for input_nodes, seeds, blocks in loader: with loader.enable_cpu_affinity():
blocks = [blk.to(device) for blk in blocks] for input_nodes, seeds, blocks in loader:
seeds = seeds[category] blocks = [blk.to(device) for blk in blocks]
emb = extract_embed(node_embed, input_nodes) seeds = seeds[category]
emb = {k: e.to(device) for k, e in emb.items()} emb = extract_embed(node_embed, input_nodes)
lbl = labels[seeds].to(device) emb = {k: e.to(device) for k, e in emb.items()}
logits = model(emb, blocks)[category] lbl = labels[seeds].to(device)
loss = F.cross_entropy(logits, lbl) logits = model(emb, blocks)[category]
acc = th.sum(logits.argmax(dim=1) == lbl).item() loss = F.cross_entropy(logits, lbl)
total_loss += loss.item() * len(seeds) acc = th.sum(logits.argmax(dim=1) == lbl).item()
total_acc += acc total_loss += loss.item() * len(seeds)
count += len(seeds) total_acc += acc
count += len(seeds)
return total_loss / count, total_acc / count return total_loss / count, total_acc / count
...@@ -86,6 +87,12 @@ def main(args): ...@@ -86,6 +87,12 @@ def main(args):
labels = labels.to(device) labels = labels.to(device)
embed_layer = embed_layer.to(device) embed_layer = embed_layer.to(device)
if args.num_workers <= 0:
raise ValueError(
"The '--num_workers' parameter value is expected "
"to be >0, but got {}.".format(args.num_workers)
)
node_embed = embed_layer() node_embed = embed_layer()
# create model # create model
model = EntityClassify( model = EntityClassify(
...@@ -111,7 +118,7 @@ def main(args): ...@@ -111,7 +118,7 @@ def main(args):
sampler, sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
num_workers=0, num_workers=args.num_workers,
) )
# validation sampler # validation sampler
...@@ -125,7 +132,7 @@ def main(args): ...@@ -125,7 +132,7 @@ def main(args):
val_sampler, val_sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
num_workers=0, num_workers=args.num_workers,
) )
# optimizer # optimizer
...@@ -134,53 +141,59 @@ def main(args): ...@@ -134,53 +141,59 @@ def main(args):
# training loop # training loop
print("start training...") print("start training...")
dur = [] mean = 0
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
model.train() model.train()
optimizer.zero_grad() optimizer.zero_grad()
if epoch > 3: if epoch > 3:
t0 = time.time() t0 = time.time()
for i, (input_nodes, seeds, blocks) in enumerate(loader): with loader.enable_cpu_affinity():
blocks = [blk.to(device) for blk in blocks] for i, (input_nodes, seeds, blocks) in enumerate(loader):
seeds = seeds[ blocks = [blk.to(device) for blk in blocks]
category seeds = seeds[
] # we only predict the nodes with type "category" category
batch_tic = time.time() ] # we only predict the nodes with type "category"
emb = extract_embed(node_embed, input_nodes) batch_tic = time.time()
lbl = labels[seeds] emb = extract_embed(node_embed, input_nodes)
if use_cuda: lbl = labels[seeds]
emb = {k: e.cuda() for k, e in emb.items()} if use_cuda:
lbl = lbl.cuda() emb = {k: e.cuda() for k, e in emb.items()}
logits = model(emb, blocks)[category] lbl = lbl.cuda()
loss = F.cross_entropy(logits, lbl) logits = model(emb, blocks)[category]
loss.backward() loss = F.cross_entropy(logits, lbl)
optimizer.step() loss.backward()
optimizer.step()
train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(seeds) train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(
print( seeds
"Epoch {:05d} | Batch {:03d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Time: {:.4f}".format( )
epoch, i, train_acc, loss.item(), time.time() - batch_tic print(
f"Epoch {epoch:05d} | Batch {i:03d} | Train Acc: "
"{train_acc:.4f} | Train Loss: {loss.item():.4f} | Time: "
"{time.time() - batch_tic:.4f}"
) )
)
if epoch > 3: if epoch > 3:
dur.append(time.time() - t0) mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)
val_loss, val_acc = evaluate( val_loss, val_acc = evaluate(
model, val_loader, node_embed, labels, category, device model, val_loader, node_embed, labels, category, device
) )
print( print(
"Epoch {:05d} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".format( f"Epoch {epoch:05d} | Valid Acc: {val_acc:.4f} | Valid loss: "
epoch, val_acc, val_loss, np.average(dur) "{val_loss:.4f} | Time: {mean:.4f}"
) )
)
print() print()
if args.model_path is not None: if args.model_path is not None:
th.save(model.state_dict(), args.model_path) th.save(model.state_dict(), args.model_path)
output = model.inference( output = model.inference(
g, args.batch_size, "cuda" if use_cuda else "cpu", 0, node_embed g,
args.batch_size,
"cuda" if use_cuda else "cpu",
args.num_workers,
node_embed,
) )
test_pred = output[category][test_idx] test_pred = output[category][test_idx]
test_labels = labels[test_idx].to(test_pred.device) test_labels = labels[test_idx].to(test_pred.device)
...@@ -245,6 +258,10 @@ if __name__ == "__main__": ...@@ -245,6 +258,10 @@ if __name__ == "__main__":
"be undesired if they cannot fit in GPU memory at once. " "be undesired if they cannot fit in GPU memory at once. "
"This flag disables that.", "This flag disables that.",
) )
parser.add_argument(
"--num_workers", type=int, default=4, help="Number of node dataloader"
)
fp = parser.add_mutually_exclusive_group(required=False) fp = parser.add_mutually_exclusive_group(required=False)
fp.add_argument("--validation", dest="validation", action="store_true") fp.add_argument("--validation", dest="validation", action="store_true")
fp.add_argument("--testing", dest="validation", action="store_false") fp.add_argument("--testing", dest="validation", action="store_false")
......
...@@ -423,17 +423,18 @@ class EntityClassify(nn.Module): ...@@ -423,17 +423,18 @@ class EntityClassify(nn.Module):
num_workers=num_workers, num_workers=num_workers,
) )
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): with dataloader.enable_cpu_affinity():
block = blocks[0].to(device) for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0].to(device)
h = {
k: x[k][input_nodes[k]].to(device) h = {
for k in input_nodes.keys() k: x[k][input_nodes[k]].to(device)
} for k in input_nodes.keys()
h = layer(block, h) }
h = layer(block, h)
for k in output_nodes.keys():
y[k][output_nodes[k]] = h[k].cpu() for k in output_nodes.keys():
y[k][output_nodes[k]] = h[k].cpu()
x = y x = y
return y return y
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment