Unverified commit e682fa74, authored by Chendi.Xue, committed by GitHub
Browse files

[EXAMPLE]Add device support in hetero_rgcn (#4677)



* Add device support in hetero_rgcn
Signed-off-by: Xue, Chendi <chendi.xue@intel.com>

* Use num_workers instead of a hard-coded value and enable cpu_affinity for PyTorch > 1.12
Signed-off-by: Xue, Chendi <chendi.xue@intel.com>

* Remove hard-coded num_workers and use the DGL version to check whether cpu_affinity is supported
Signed-off-by: Xue, Chendi <chendi.xue@intel.com>

* Remove specified dl_cores and computer_cores

Add an error message if num_workers is set incorrectly
Signed-off-by: Xue, Chendi <chendi.xue@intel.com>

* Expected num_workers should be less than num_physical_cores
Signed-off-by: Chendi Xue <chendi.xue@intel.com>

* Update examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py
Co-authored-by: Chang Liu <chang.liu@utexas.edu>

* Remove DGL version and num_workers print
Signed-off-by: Xue, Chendi <chendi.xue@intel.com>

* add comment to explain is_support_affinity

* Fix typo
Signed-off-by: Xue, Chendi <chendi.xue@intel.com>
Signed-off-by: Xue, Chendi <chendi.xue@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Chang Liu <chang.liu@utexas.edu>
parent 9b62e8d0
......@@ -12,8 +12,12 @@ import dgl.nn as dglnn
from dgl import AddReverse, Compose, ToSimple
from dgl.nn import HeteroEmbedding
import psutil
import sys
def prepare_data(args):
v_t = dgl.__version__
def prepare_data(args, device):
dataset = DglNodePropPredDataset(name="ogbn-mag")
split_idx = dataset.get_idx_split()
# graph: dgl graph object, label: torch tensor of shape (num_nodes, num_tasks)
......@@ -29,13 +33,15 @@ def prepare_data(args):
# train sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 20])
num_workers = args.num_workers
train_loader = dgl.dataloading.DataLoader(
g,
split_idx["train"],
sampler,
batch_size=1024,
shuffle=True,
num_workers=0,
num_workers=num_workers,
device=device
)
return g, labels, dataset.num_classes, split_idx, logger, train_loader
......@@ -314,6 +320,7 @@ def test(g, model, node_embed, y_true, device, split_idx):
batch_size=16384,
shuffle=False,
num_workers=0,
device=device
)
pbar = tqdm(total=y_true.size(0))
......@@ -365,13 +372,16 @@ def test(g, model, node_embed, y_true, device, split_idx):
return train_acc, valid_acc, test_acc
def is_support_affinity(v_t):
    """Return True if the given DGL version supports ``enable_cpu_affinity``.

    DGL supports ``enable_cpu_affinity`` since 0.9.1. The version is compared
    numerically, component by component, rather than as a string: a plain
    string comparison orders ``"0.10.0"`` before ``"0.9.1"`` and would wrongly
    report newer releases as unsupported.

    Args:
        v_t: Version string such as ``dgl.__version__`` (e.g. ``"0.9.1"``,
            ``"0.9.1post1"``, ``"1.1.0+cu118"``).

    Returns:
        bool: True when the leading numeric release of ``v_t`` is at least
        0.9.1. A string with no leading numeric release yields False, so
        callers fall back to the path that does not use CPU affinity.
    """
    # Local import: only part of the file's import section is visible here.
    import re

    # Keep only the leading dotted-number release segment; suffixes such as
    # "post1" or "+cu118" are ignored for the comparison.
    # NOTE(review): pre-release tags (e.g. "0.9.1a1") are treated like the
    # final release - confirm this is acceptable for DGL's version scheme.
    match = re.match(r"\s*v?(\d+(?:\.\d+)*)", str(v_t))
    if match is None:
        return False
    release = tuple(int(part) for part in match.group(1).split("."))
    return release >= (0, 9, 1)
def main(args):
device = f"cuda:0" if th.cuda.is_available() else "cpu"
g, labels, num_classes, split_idx, logger, train_loader = prepare_data(args)
g, labels, num_classes, split_idx, logger, train_loader = prepare_data(args, device)
embed_layer = rel_graph_embed(g, 128)
embed_layer = rel_graph_embed(g, 128).to(device)
model = EntityClassify(g, 128, num_classes).to(device)
print(
......@@ -383,8 +393,12 @@ def main(args):
for run in range(args.runs):
embed_layer.reset_parameters()
model.reset_parameters()
try:
embed_layer.reset_parameters()
model.reset_parameters()
except:
# old pytorch version doesn't support reset_parameters() API
pass
# optimizer
all_params = itertools.chain(
......@@ -392,18 +406,36 @@ def main(args):
)
optimizer = th.optim.Adam(all_params, lr=0.01)
logger = train(
g,
model,
embed_layer,
optimizer,
train_loader,
split_idx,
labels,
logger,
device,
run,
)
if args.num_workers != 0 and device == "cpu" and is_support_affinity(v_t):
expected_max = int(psutil.cpu_count(logical=False))
if args.num_workers >= expected_max:
print(f"[ERROR] You specified num_workers are larger than physical cores, please set any number less than {expected_max}", file=sys.stderr)
with train_loader.enable_cpu_affinity():
logger = train(
g,
model,
embed_layer,
optimizer,
train_loader,
split_idx,
labels,
logger,
device,
run,
)
else:
logger = train(
g,
model,
embed_layer,
optimizer,
train_loader,
split_idx,
labels,
logger,
device,
run,
)
logger.print_statistics(run)
print("Final performance: ")
......@@ -413,6 +445,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="RGCN")
parser.add_argument("--runs", type=int, default=10)
parser.add_argument("--num_workers", type=int, default=0)
args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment