Unverified Commit 3e139033 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] RGCN example (#7137)

parent d95720b9
@@ -114,6 +114,11 @@ def create_dataloader(
# Whether to shuffle the items in the dataset before sampling.
datapipe = gb.ItemSampler(item_set, batch_size=batch_size, shuffle=shuffle)
# Move the mini-batch to the appropriate device.
# `device`:
# The device to move the mini-batch to.
datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
# Sample neighbors for each seed node in the mini-batch.
# `graph`:
# The graph (FusedCSCSamplingGraph) from which to sample neighbors.
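The hunk above moves the `copy_to` step ahead of neighbor sampling, so the mini-batch and its seed nodes (via `extra_attrs=["seed_nodes"]`) are already on the GPU when sampling and feature fetching run. Below is a minimal sketch of the resulting pipeline, assuming `item_set`, `graph`, `features`, `fanouts`, and `device` are set up as elsewhere in the example; the `"paper"` feature key and the function name are placeholders, not part of the change itself.

```python
import dgl.graphbolt as gb


def create_dataloader_sketch(item_set, graph, features, fanouts, device, batch_size=1024):
    # Batch the seed nodes.
    datapipe = gb.ItemSampler(item_set, batch_size=batch_size, shuffle=True)
    # Move the mini-batch (including its seed nodes) to `device` *before*
    # sampling, so neighbor sampling and feature fetching happen on the GPU.
    datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
    # Sample neighbors from the (pinned or GPU-resident) graph.
    datapipe = datapipe.sample_neighbor(graph, fanouts)
    # Fetch the features needed by the sampled subgraphs.
    datapipe = datapipe.fetch_feature(features, node_feature_keys={"paper": ["feat"]})
    # Wrap the pipeline in a DataLoader.
    return gb.DataLoader(datapipe)
```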
@@ -133,17 +138,6 @@ def create_dataloader(
node_feature_keys["institution"] = ["feat"]
datapipe = datapipe.fetch_feature(features, node_feature_keys)
# Move the mini-batch to the appropriate device.
# `device`:
# The device to move the mini-batch to.
# [Rui] Usually, we move the mini-batch to the target device in the datapipe.
# However, in this example, we leave the mini-batch on CPU and move it to
# GPU after the blocks are created. This is because the example keeps the
# GPU busy with the embedding layer, so block creation on CPU can be
# overlapped with the optimization step on GPU, which yields better
# performance.
device = torch.device("cpu")
datapipe = datapipe.copy_to(device)
# Create a DataLoader from the datapipe.
# `num_workers`:
# The number of worker processes to use for data loading.
@@ -532,7 +526,7 @@ def train(
# Generate predictions.
logits = model(blocks, node_features)[category]
y_hat = logits.log_softmax(dim=-1).cpu()
y_hat = logits.log_softmax(dim=-1)
loss = F.nll_loss(y_hat, data.labels[category].long())
loss.backward()
optimizer.step()
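Dropping the `.cpu()` call keeps the log-probabilities on the same device as the logits, so the loss is computed without a host round-trip; the labels then have to live on that device as well. A small sketch of that pattern, with random tensors standing in for the model output and for `data.labels[category]`:

```python
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logits = torch.randn(8, 5, device=device)          # stand-in for model output
labels = torch.randint(0, 5, (8,), device=device)  # stand-in for data.labels[category]
y_hat = logits.log_softmax(dim=-1)                 # stays on `device`
loss = F.nll_loss(y_hat, labels.long())            # no .cpu() round-trip needed
```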
@@ -559,7 +553,9 @@ def train(
def main(args):
device = torch.device("cuda") if args.num_gpus > 0 else torch.device("cpu")
device = torch.device(
"cuda" if args.num_gpus > 0 and torch.cuda.is_available() else "cpu"
)
# Load dataset.
(
@@ -571,6 +567,11 @@ def main(args):
num_classes,
) = load_dataset(args.dataset)
# Move the dataset to the pinned memory to enable GPU access.
if device == torch.device("cuda"):
g.pin_memory_()
features.pin_memory_()
feat_size = features.size("node", "paper", "feat")[0]
# As `ogb-lsc-mag240m` is a large dataset, features of `author` and
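The device selection and the new pinning calls work together: when a GPU is selected, the graph and the feature store are pinned in host memory so the GPU can read them directly during sampling and feature fetching. A minimal sketch of that setup follows; `g` and `features` stand in for the objects returned by the example's `load_dataset`, and the helper name is invented for illustration.

```python
import torch


def maybe_pin_for_gpu(g, features, num_gpus):
    """Pick the device and, if it is a GPU, pin graph/features in host memory."""
    device = torch.device(
        "cuda" if num_gpus > 0 and torch.cuda.is_available() else "cpu"
    )
    if device.type == "cuda":
        g.pin_memory_()          # in-place: pins the sampling graph
        features.pin_memory_()   # in-place: pins the feature store
    return device
```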
@@ -657,7 +658,7 @@ if __name__ == "__main__":
)
parser.add_argument("--num_epochs", type=int, default=3)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--num_gpus", type=int, default=0)
parser.add_argument("--num_gpus", type=int, default=1)
args = parser.parse_args()