Unverified Commit d20db1ec authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving the RGCN_HETERO example (#6060)

parent 566719b1
...@@ -28,6 +28,7 @@ def evaluate(model, loader, node_embed, labels, category, device): ...@@ -28,6 +28,7 @@ def evaluate(model, loader, node_embed, labels, category, device):
total_loss = 0 total_loss = 0
total_acc = 0 total_acc = 0
count = 0 count = 0
with loader.enable_cpu_affinity():
for input_nodes, seeds, blocks in loader: for input_nodes, seeds, blocks in loader:
blocks = [blk.to(device) for blk in blocks] blocks = [blk.to(device) for blk in blocks]
seeds = seeds[category] seeds = seeds[category]
...@@ -86,6 +87,12 @@ def main(args): ...@@ -86,6 +87,12 @@ def main(args):
labels = labels.to(device) labels = labels.to(device)
embed_layer = embed_layer.to(device) embed_layer = embed_layer.to(device)
if args.num_workers <= 0:
raise ValueError(
"The '--num_workers' parameter value is expected "
"to be >0, but got {}.".format(args.num_workers)
)
node_embed = embed_layer() node_embed = embed_layer()
# create model # create model
model = EntityClassify( model = EntityClassify(
...@@ -111,7 +118,7 @@ def main(args): ...@@ -111,7 +118,7 @@ def main(args):
sampler, sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
num_workers=0, num_workers=args.num_workers,
) )
# validation sampler # validation sampler
...@@ -125,7 +132,7 @@ def main(args): ...@@ -125,7 +132,7 @@ def main(args):
val_sampler, val_sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
num_workers=0, num_workers=args.num_workers,
) )
# optimizer # optimizer
...@@ -134,13 +141,14 @@ def main(args): ...@@ -134,13 +141,14 @@ def main(args):
# training loop # training loop
print("start training...") print("start training...")
dur = [] mean = 0
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
model.train() model.train()
optimizer.zero_grad() optimizer.zero_grad()
if epoch > 3: if epoch > 3:
t0 = time.time() t0 = time.time()
with loader.enable_cpu_affinity():
for i, (input_nodes, seeds, blocks) in enumerate(loader): for i, (input_nodes, seeds, blocks) in enumerate(loader):
blocks = [blk.to(device) for blk in blocks] blocks = [blk.to(device) for blk in blocks]
seeds = seeds[ seeds = seeds[
...@@ -157,30 +165,35 @@ def main(args): ...@@ -157,30 +165,35 @@ def main(args):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(seeds) train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(
print( seeds
"Epoch {:05d} | Batch {:03d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Time: {:.4f}".format(
epoch, i, train_acc, loss.item(), time.time() - batch_tic
) )
print(
f"Epoch {epoch:05d} | Batch {i:03d} | Train Acc: "
"{train_acc:.4f} | Train Loss: {loss.item():.4f} | Time: "
"{time.time() - batch_tic:.4f}"
) )
if epoch > 3: if epoch > 3:
dur.append(time.time() - t0) mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)
val_loss, val_acc = evaluate( val_loss, val_acc = evaluate(
model, val_loader, node_embed, labels, category, device model, val_loader, node_embed, labels, category, device
) )
print( print(
"Epoch {:05d} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".format( f"Epoch {epoch:05d} | Valid Acc: {val_acc:.4f} | Valid loss: "
epoch, val_acc, val_loss, np.average(dur) "{val_loss:.4f} | Time: {mean:.4f}"
)
) )
print() print()
if args.model_path is not None: if args.model_path is not None:
th.save(model.state_dict(), args.model_path) th.save(model.state_dict(), args.model_path)
output = model.inference( output = model.inference(
g, args.batch_size, "cuda" if use_cuda else "cpu", 0, node_embed g,
args.batch_size,
"cuda" if use_cuda else "cpu",
args.num_workers,
node_embed,
) )
test_pred = output[category][test_idx] test_pred = output[category][test_idx]
test_labels = labels[test_idx].to(test_pred.device) test_labels = labels[test_idx].to(test_pred.device)
...@@ -245,6 +258,10 @@ if __name__ == "__main__": ...@@ -245,6 +258,10 @@ if __name__ == "__main__":
"be undesired if they cannot fit in GPU memory at once. " "be undesired if they cannot fit in GPU memory at once. "
"This flag disables that.", "This flag disables that.",
) )
parser.add_argument(
"--num_workers", type=int, default=4, help="Number of node dataloader"
)
fp = parser.add_mutually_exclusive_group(required=False) fp = parser.add_mutually_exclusive_group(required=False)
fp.add_argument("--validation", dest="validation", action="store_true") fp.add_argument("--validation", dest="validation", action="store_true")
fp.add_argument("--testing", dest="validation", action="store_false") fp.add_argument("--testing", dest="validation", action="store_false")
......
...@@ -423,6 +423,7 @@ class EntityClassify(nn.Module): ...@@ -423,6 +423,7 @@ class EntityClassify(nn.Module):
num_workers=num_workers, num_workers=num_workers,
) )
with dataloader.enable_cpu_affinity():
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0].to(device) block = blocks[0].to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment