"third_party/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "c454d419cc5e036daaf8ebf73ccb82fa751a5cd0"
Unverified Commit 3e139033 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] RGCN example (#7137)

parent d95720b9
...@@ -114,6 +114,11 @@ def create_dataloader( ...@@ -114,6 +114,11 @@ def create_dataloader(
# Whether to shuffle the items in the dataset before sampling. # Whether to shuffle the items in the dataset before sampling.
datapipe = gb.ItemSampler(item_set, batch_size=batch_size, shuffle=shuffle) datapipe = gb.ItemSampler(item_set, batch_size=batch_size, shuffle=shuffle)
# Move the mini-batch to the appropriate device.
# `device`:
# The device to move the mini-batch to.
datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
# Sample neighbors for each seed node in the mini-batch. # Sample neighbors for each seed node in the mini-batch.
# `graph`: # `graph`:
# The graph(FusedCSCSamplingGraph) from which to sample neighbors. # The graph(FusedCSCSamplingGraph) from which to sample neighbors.
...@@ -133,17 +138,6 @@ def create_dataloader( ...@@ -133,17 +138,6 @@ def create_dataloader(
node_feature_keys["institution"] = ["feat"] node_feature_keys["institution"] = ["feat"]
datapipe = datapipe.fetch_feature(features, node_feature_keys) datapipe = datapipe.fetch_feature(features, node_feature_keys)
# Move the mini-batch to the appropriate device.
# `device`:
# The device to move the mini-batch to.
# [Rui] Usually, we move the mini-batch to target device in the datapipe.
# However, in this example, we leaves the mini-batch on CPU and move it to
# GPU after blocks are created. This is because this example is busy on
# GPU due to embedding layer. And block creation on CPU could be overlapped
# with optimization operation on GPU and it results in better performance.
device = torch.device("cpu")
datapipe = datapipe.copy_to(device)
# Create a DataLoader from the datapipe. # Create a DataLoader from the datapipe.
# `num_workers`: # `num_workers`:
# The number of worker processes to use for data loading. # The number of worker processes to use for data loading.
...@@ -532,7 +526,7 @@ def train( ...@@ -532,7 +526,7 @@ def train(
# Generate predictions. # Generate predictions.
logits = model(blocks, node_features)[category] logits = model(blocks, node_features)[category]
y_hat = logits.log_softmax(dim=-1).cpu() y_hat = logits.log_softmax(dim=-1)
loss = F.nll_loss(y_hat, data.labels[category].long()) loss = F.nll_loss(y_hat, data.labels[category].long())
loss.backward() loss.backward()
optimizer.step() optimizer.step()
...@@ -559,7 +553,9 @@ def train( ...@@ -559,7 +553,9 @@ def train(
def main(args): def main(args):
device = torch.device("cuda") if args.num_gpus > 0 else torch.device("cpu") device = torch.device(
"cuda" if args.num_gpus > 0 and torch.cuda.is_available() else "cpu"
)
# Load dataset. # Load dataset.
( (
...@@ -571,6 +567,11 @@ def main(args): ...@@ -571,6 +567,11 @@ def main(args):
num_classes, num_classes,
) = load_dataset(args.dataset) ) = load_dataset(args.dataset)
# Move the dataset to the pinned memory to enable GPU access.
if device == torch.device("cuda"):
g.pin_memory_()
features.pin_memory_()
feat_size = features.size("node", "paper", "feat")[0] feat_size = features.size("node", "paper", "feat")[0]
# As `ogb-lsc-mag240m` is a large dataset, features of `author` and # As `ogb-lsc-mag240m` is a large dataset, features of `author` and
...@@ -657,7 +658,7 @@ if __name__ == "__main__": ...@@ -657,7 +658,7 @@ if __name__ == "__main__":
) )
parser.add_argument("--num_epochs", type=int, default=3) parser.add_argument("--num_epochs", type=int, default=3)
parser.add_argument("--num_workers", type=int, default=0) parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--num_gpus", type=int, default=0) parser.add_argument("--num_gpus", type=int, default=1)
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment