Unverified Commit e6a15c1a authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Enable `node_classification.py` to run on GPU environment (#6490)

parent 2a92dfca
......@@ -47,7 +47,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from tqdm import tqdm
def create_dataloader(
......@@ -145,6 +145,16 @@ def create_dataloader(
############################################################################
# [Step-5]:
# self.copy_to()
# [Input]:
# 'device': The device to copy the data to.
# [Output]:
# A CopyTo object to copy the data to the specified device.
############################################################################
datapipe = datapipe.copy_to(device=args.device)
############################################################################
# [Step-6]:
# gb.MultiProcessDataLoader()
# [Input]:
# 'datapipe': The datapipe object to be used for data loading.
......@@ -191,10 +201,15 @@ class SAGE(nn.Module):
hidden_x = self.dropout(hidden_x)
return hidden_x
def inference(self, graph, features, dataloader):
def inference(self, graph, features, dataloader, device):
"""Conduct layer-wise inference to get all the node embeddings."""
feature = features.read("node", None, "feat")
buffer_device = torch.device("cpu")
# Enable pin_memory for faster CPU to GPU data transfer if the
# model is running on a GPU.
pin_memory = buffer_device != device
for layer_idx, layer in enumerate(self.layers):
is_last_layer = layer_idx == len(self.layers) - 1
......@@ -202,16 +217,21 @@ class SAGE(nn.Module):
graph.total_num_nodes,
self.out_size if is_last_layer else self.hidden_size,
dtype=torch.float64,
device=buffer_device,
pin_memory=pin_memory,
)
feature = feature.to(device)
for step, data in tqdm.tqdm(enumerate(dataloader)):
for step, data in tqdm(enumerate(dataloader)):
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
if not is_last_layer:
hidden_x = F.relu(hidden_x)
hidden_x = self.dropout(hidden_x)
# By design, our output nodes are contiguous.
y[data.output_nodes[0] : data.output_nodes[-1] + 1] = hidden_x
y[
data.output_nodes[0] : data.output_nodes[-1] + 1
] = hidden_x.to(buffer_device)
feature = y
return y
......@@ -225,7 +245,7 @@ def layerwise_infer(
dataloader = create_dataloader(
args, graph, features, all_nodes_set, job="infer"
)
pred = model.inference(graph, features, dataloader)
pred = model.inference(graph, features, dataloader, args.device)
pred = pred[test_set._items[0]]
label = test_set._items[1].to(pred.device)
......@@ -246,7 +266,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
args, graph, features, itemset, job="evaluate"
)
for step, data in tqdm.tqdm(enumerate(dataloader)):
for step, data in tqdm(enumerate(dataloader)):
x = data.node_features["feat"]
y.append(data.labels)
y_hats.append(model(data.blocks, x))
......@@ -265,10 +285,10 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
args, graph, features, train_set, job="train"
)
for epoch in tqdm.trange(args.epochs):
for epoch in range(args.epochs):
model.train()
total_loss = 0
for step, data in tqdm.tqdm(enumerate(dataloader)):
for step, data in tqdm(enumerate(dataloader)):
# The input features from the source nodes in the first layer's
# computation graph.
x = data.node_features["feat"]
......@@ -326,14 +346,28 @@ def parse_args():
help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
" identical with the number of layers in your model. Default: 15,10,5",
)
parser.add_argument(
"--device",
default="cpu",
choices=["cpu", "cuda"],
help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
)
return parser.parse_args()
def main(args):
if not torch.cuda.is_available():
args.device = "cpu"
print(f"Training in {args.device} mode.")
args.device = torch.device(args.device)
# Load and preprocess dataset.
print("Loading data...")
dataset = gb.BuiltinDataset("ogbn-products").load()
graph = dataset.graph
# Currently the neighbor-sampling process can only be done on the CPU,
# therefore there is no need to copy the graph to the GPU.
features = dataset.feature
train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
......@@ -348,6 +382,8 @@ def main(args):
out_size = num_classes
model = SAGE(in_size, hidden_size, out_size)
assert len(args.fanout) == len(model.layers)
model = model.to(args.device)
# Model training.
print("Training...")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment